Mahiruoshi committed
Commit 8e84ca1
1 Parent(s): cb3cd10

Update server.py

Files changed (1)
  1. server.py +207 -670
server.py CHANGED
@@ -3,7 +3,6 @@ import os
 from pathlib import Path
 
 import logging
-import uuid
 import re_matching
 
 logging.getLogger("numba").setLevel(logging.WARNING)
@@ -16,8 +15,7 @@ logging.basicConfig(
 )
 
 logger = logging.getLogger(__name__)
-import shutil
-from scipy.io.wavfile import write
 import librosa
 import numpy as np
 import torch
@@ -25,6 +23,11 @@ import torch.nn as nn
 from torch.utils.data import Dataset
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
 
 import gradio as gr
 
@@ -40,28 +43,9 @@ import utils
 from models import SynthesizerTrn
 from text.symbols import symbols
 import sys
-import re
-
-import random
-import hashlib
-
-from fugashi import Tagger
-import jaconv
-import unidic
-import subprocess
-
-import requests
-
-from ebooklib import epub
-import PyPDF2
-from PyPDF2 import PdfReader
-from bs4 import BeautifulSoup
-import jieba
-import romajitable
-
-from flask import Flask, request, jsonify, render_template_string, send_file
-from flask_cors import CORS
 from scipy.io.wavfile import write
 net_g = None
 
 device = (
@@ -91,359 +75,6 @@ BandList = {
     "西克菲尔特音乐学院":["晶","未知留","八千代","栞","美帆"]
 }
 
-webBase = 'https://mahiruoshi-bangdream-bert-vits2.hf.space/'
-
-port = 8080
-
-languages = [ "Auto", "ZH", "JP"]
-modelPaths = []
-modes = ['pyopenjtalk-V2.3-Katakana','fugashi-V2.3-Katakana','pyopenjtalk-V2.3-Katakana-Katakana','fugashi-V2.3-Katakana-Katakana','onnx-V2.3']
-sentence_modes = ['sentence','paragraph']
-for dirpath, dirnames, filenames in os.walk('Data/BangDream/models/'):
-    for filename in filenames:
-        modelPaths.append(os.path.join(dirpath, filename))
-hps = utils.get_hparams_from_file('Data/BangDream/config.json')
-
-def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""):
-    """
-    :param Sentence: the sentence to translate
-    :param from_Language: language of the source sentence
-    :param to_Language: target language
-    :return: the translated sentence, or None on error
-
-    Common language codes: Chinese zh, English en, Japanese jp
-    """
-    appid = "20231117001883321"
-    key = "lMQbvZHeJveDceLof2wf"
-    if appid == "" or key == "":
-        return "请开发者在config.yml中配置app_key与secret_key"
-    url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
-    texts = Sentence.splitlines()
-    outTexts = []
-    for t in texts:
-        if t != "":
-            # Signature calculation; see https://api.fanyi.baidu.com/product/113
-            salt = str(random.randint(1, 100000))
-            signString = appid + t + salt + key
-            hs = hashlib.md5()
-            hs.update(signString.encode("utf-8"))
-            signString = hs.hexdigest()
-            if from_Language == "":
-                from_Language = "auto"
-            headers = {"Content-Type": "application/x-www-form-urlencoded"}
-            payload = {
-                "q": t,
-                "from": from_Language,
-                "to": to_Language,
-                "appid": appid,
-                "salt": salt,
-                "sign": signString,
-            }
-            # Send the request
-            try:
-                response = requests.post(
-                    url=url, data=payload, headers=headers, timeout=3
-                )
-                response = response.json()
-                if "trans_result" in response.keys():
-                    result = response["trans_result"][0]
-                    if "dst" in result.keys():
-                        dst = result["dst"]
-                        outTexts.append(dst)
-            except Exception:
-                return Sentence
-        else:
-            outTexts.append(t)
-    return "\n".join(outTexts)
-
-# Text-cleaning utilities
-def is_japanese(string):
-    for ch in string:
-        if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
-            return True
-    return False
-
-def is_chinese(string):
-    for ch in string:
-        if '\u4e00' <= ch <= '\u9fff':
-            return True
-    return False
-
-def is_single_language(sentence):
-    # Check whether the sentence is in a single language
-    contains_chinese = re.search(r'[\u4e00-\u9fff]', sentence) is not None
-    contains_japanese = re.search(r'[\u3040-\u30ff\u31f0-\u31ff]', sentence) is not None
-    contains_english = re.search(r'[a-zA-Z]', sentence) is not None
-    language_count = sum([contains_chinese, contains_japanese, contains_english])
-    return language_count == 1
-
-def merge_scattered_parts(sentences):
-    """Merge scattered fragments into adjacent sentences, keeping each result single-language"""
-    merged_sentences = []
-    buffer_sentence = ""
-
-    for sentence in sentences:
-        # Check whether it is single-language, or too short (possibly punctuation or a single word)
-        if is_single_language(sentence) and len(sentence) > 1:
-            # If the buffer has content, flush it to the list first
-            if buffer_sentence:
-                merged_sentences.append(buffer_sentence)
-                buffer_sentence = ""
-            merged_sentences.append(sentence)
-        else:
-            # Otherwise it is a scattered fragment; add it to the buffer
-            buffer_sentence += sentence
-
-    # Make sure the final buffer content is added
-    if buffer_sentence:
-        merged_sentences.append(buffer_sentence)
-
-    return merged_sentences
-
-def is_only_punctuation(s):
-    """Check whether the string contains only punctuation"""
-    # Common Chinese, Japanese and English punctuation
-    punctuation_pattern = re.compile(r'^[\s。*;,:“”()、!?《》\u3000\.,;:"\'?!()]+$')
-    return punctuation_pattern.match(s) is not None
-
-def split_mixed_language(sentence):
-    # Split a mixed-language sentence
-    # Check character by character and split runs of different languages
-    sub_sentences = []
-    current_language = None
-    current_part = ""
-
-    for char in sentence:
-        if re.match(r'[\u4e00-\u9fff]', char):  # Chinese character
-            if current_language != 'chinese':
-                if current_part:
-                    sub_sentences.append(current_part)
-                current_part = char
-                current_language = 'chinese'
-            else:
-                current_part += char
-        elif re.match(r'[\u3040-\u30ff\u31f0-\u31ff]', char):  # Japanese character
-            if current_language != 'japanese':
-                if current_part:
-                    sub_sentences.append(current_part)
-                current_part = char
-                current_language = 'japanese'
-            else:
-                current_part += char
-        elif re.match(r'[a-zA-Z]', char):  # English character
-            if current_language != 'english':
-                if current_part:
-                    sub_sentences.append(current_part)
-                current_part = char
-                current_language = 'english'
-            else:
-                current_part += char
-        else:
-            current_part += char  # For punctuation and other characters
-
-    if current_part:
-        sub_sentences.append(current_part)
-
-    return sub_sentences
-
-def replace_quotes(text):
-    # Replace Chinese/Japanese quotation marks with English quotes
-    text = re.sub(r'[“”‘’『』「」()()]', '"', text)
-    return text
-
-def remove_numeric_annotations(text):
-    # Regex matching numeric annotations,
-    # i.e. numbers wrapped in “”, 【】 or 〔〕
-    pattern = r'“\d+”|【\d+】|〔\d+〕'
-    # Strip the annotations with the regex
-    cleaned_text = re.sub(pattern, '', text)
-    return cleaned_text
-
-def merge_adjacent_japanese(sentences):
-    """Merge adjacent sentences that are all Japanese"""
-    merged_sentences = []
-    i = 0
-    while i < len(sentences):
-        current_sentence = sentences[i]
-        if i + 1 < len(sentences) and is_japanese(current_sentence) and is_japanese(sentences[i + 1]):
-            # The current and next sentences are both Japanese; merge them
-            while i + 1 < len(sentences) and is_japanese(sentences[i + 1]):
-                current_sentence += sentences[i + 1]
-                i += 1
-        merged_sentences.append(current_sentence)
-        i += 1
-    return merged_sentences
-
-def extrac(text):
-    text = replace_quotes(remove_numeric_annotations(text))  # replace quotes
-    text = re.sub("<[^>]*>", "", text)  # remove HTML tags
-    # Initial split on newlines and sentence-final punctuation
-    preliminary_sentences = re.split(r'([\n。;!?\.\?!])', text)
-    final_sentences = []
-
-    preliminary_sentences = re.split(r'([\n。;!?\.\?!])', text)
-
-    for piece in preliminary_sentences:
-        if is_single_language(piece):
-            final_sentences.append(piece)
-        else:
-            sub_sentences = split_mixed_language(piece)
-            final_sentences.extend(sub_sentences)
-
-    # Split long sentences, segmenting with jieba
-    split_sentences = []
-    for sentence in final_sentences:
-        split_sentences.extend(split_long_sentences(sentence))
-
-    # Merge adjacent Japanese sentences
-    merged_japanese_sentences = merge_adjacent_japanese(split_sentences)
-
-    # Drop elements that contain only punctuation
-    clean_sentences = [s for s in merged_japanese_sentences if not is_only_punctuation(s)]
-
-    # Remove empty strings and strip stray quotes
-    return [s.replace('"','').strip() for s in clean_sentences if s]
-
-
-
-# Remove empty strings
-
-def is_mixed_language(sentence):
-    contains_chinese = re.search(r'[\u4e00-\u9fff]', sentence) is not None
-    contains_japanese = re.search(r'[\u3040-\u30ff\u31f0-\u31ff]', sentence) is not None
-    contains_english = re.search(r'[a-zA-Z]', sentence) is not None
-    languages_count = sum([contains_chinese, contains_japanese, contains_english])
-    return languages_count > 1
-
-def split_mixed_language(sentence):
-    # Split a mixed-language sentence
-    sub_sentences = re.split(r'(?<=[。!?\.\?!])(?=")|(?<=")(?=[\u4e00-\u9fff\u3040-\u30ff\u31f0-\u31ff]|[a-zA-Z])', sentence)
-    return [s.strip() for s in sub_sentences if s.strip()]
-
-def seconds_to_ass_time(seconds):
-    """Convert seconds to the ASS time format"""
-    hours = int(seconds / 3600)
-    minutes = int((seconds % 3600) / 60)
-    seconds = int(seconds) % 60
-    milliseconds = int((seconds - int(seconds)) * 1000)
-    return "{:01d}:{:02d}:{:02d}.{:02d}".format(hours, minutes, seconds, int(milliseconds / 10))
-
-def extract_text_from_epub(file_path):
-    book = epub.read_epub(file_path)
-    content = []
-    for item in book.items:
-        if isinstance(item, epub.EpubHtml):
-            soup = BeautifulSoup(item.content, 'html.parser')
-            content.append(soup.get_text())
-    return '\n'.join(content)
-
-def extract_text_from_pdf(file_path):
-    with open(file_path, 'rb') as file:
-        reader = PdfReader(file)
-        content = [page.extract_text() for page in reader.pages]
-    return '\n'.join(content)
-
-def remove_annotations(text):
-    # Remove content inside square brackets, angle brackets and CJK brackets
-    text = re.sub(r'\[.*?\]', '', text)
-    text = re.sub(r'\<.*?\>', '', text)
-    text = re.sub(r'&#8203;``【oaicite:1】``&#8203;', '', text)
-    return text
-
-def extract_text_from_file(inputFile):
-    file_extension = os.path.splitext(inputFile)[1].lower()
-    if file_extension == ".epub":
-        return extract_text_from_epub(inputFile)
-    elif file_extension == ".pdf":
-        return extract_text_from_pdf(inputFile)
-    elif file_extension == ".txt":
-        with open(inputFile, 'r', encoding='utf-8') as f:
-            return f.read()
-    else:
-        raise ValueError(f"Unsupported file format: {file_extension}")
-
-def split_by_punctuation(sentence):
-    """Split a sentence on Chinese secondary punctuation"""
-    # Common Chinese secondary separators: comma, semicolon, etc.
-    parts = re.split(r'([,,;;])', sentence)
-    # Merge each punctuation mark into the preceding part so it never stands alone
-    merged_parts = []
-    for part in parts:
-        if part and not part in ',,;;':
-            merged_parts.append(part)
-        elif merged_parts:
-            merged_parts[-1] += part
-    return merged_parts
-
-def split_long_sentences(sentence, max_length=30):
-    """If a Chinese sentence is too long, split on punctuation first, then with jieba segmentation if needed"""
-    if len(sentence) > max_length and is_chinese(sentence):
-        # First try splitting on secondary punctuation
-        preliminary_parts = split_by_punctuation(sentence)
-        new_sentences = []
-
-        for part in preliminary_parts:
-            # If a part is still too long, segment it with jieba
-            if len(part) > max_length:
-                words = jieba.lcut(part)
-                current_sentence = ""
-                for word in words:
-                    if len(current_sentence) + len(word) > max_length:
-                        new_sentences.append(current_sentence)
-                        current_sentence = word
-                    else:
-                        current_sentence += word
-                if current_sentence:
-                    new_sentences.append(current_sentence)
-            else:
-                new_sentences.append(part)
-
-        return new_sentences
-    return [sentence]  # Not long, or not Chinese: return unchanged
-
-def extract_and_convert(text):
-
-    # Find all English words with a regex
-    english_parts = re.findall(r'\b[A-Za-z]+\b', text)  # \b marks a word boundary
-
-    # Convert each English word to katakana
-    kana_parts = ['\n{}\n'.format(romajitable.to_kana(word).katakana) for word in english_parts]
-
-    # Replace the English parts in the original text
-    for eng, kana in zip(english_parts, kana_parts):
-        text = text.replace(eng, kana, 1)  # replace only one instance at a time
-
-    return text
-# Inference utilities
-def download_unidic():
-    try:
-        Tagger()
-        print("Tagger launch successfully.")
-    except Exception as e:
-        print("UNIDIC dictionary not found, downloading...")
-        subprocess.run([sys.executable, "-m", "unidic", "download"])
-        print("Download completed.")
-
-def kanji_to_hiragana(text):
-    global tagger
-    output = ""
-
-    # A regex that separates words from punctuation more accurately
-    segments = re.findall(r'[一-龥ぁ-んァ-ン\w]+|[^\一-龥ぁ-んァ-ン\w\s]', text, re.UNICODE)
-
-    for segment in segments:
-        if re.match(r'[一-龥ぁ-んァ-ン\w]+', segment):
-            # Words and kanji are converted to hiragana
-            for word in tagger(segment):
-                kana = word.feature.kana or word.surface
-                hiragana = jaconv.kata2hira(kana)  # convert katakana to hiragana
-                output += hiragana
-        else:
-            # Punctuation is left unchanged
-            output += segment
-
-    return output
-
 def get_net_g(model_path: str, device: str, hps):
     net_g = SynthesizerTrn(
         len(symbols),
@@ -498,6 +129,7 @@ def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7)
     language = torch.LongTensor(language)
     return bert, ja_bert, en_bert, phone, tone, language
 
 def infer(
     text,
     sdp_ratio,
@@ -507,23 +139,9 @@ def infer(
     sid,
     style_text=None,
     style_weight=0.7,
-    language = "Auto",
-    mode = 'pyopenjtalk-V2.3-Katakana',
-    skip_start=False,
-    skip_end=False,
 ):
-    if style_text == None:
-        style_text = ""
-        style_weight=0,
-    if mode == 'fugashi-V2.3-Katakana':
-        text = kanji_to_hiragana(text) if is_japanese(text) else text
-    if language == "JP":
-        text = translate(text,"jp")
-    if language == "ZH":
-        text = translate(text,"zh")
-    if language == "Auto":
-        language= 'JP' if is_japanese(text) else 'ZH'
-    #print(f'{text}:{sdp_ratio}:{noise_scale}:{noise_scale_w}:{length_scale}:{length_scale}:{sid}:{language}:{mode}:{skip_start}:{skip_end}')
     bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
         text,
         language,
@@ -532,20 +150,6 @@ def infer(
         style_text=style_text,
         style_weight=style_weight,
     )
-    if skip_start:
-        phones = phones[3:]
-        tones = tones[3:]
-        lang_ids = lang_ids[3:]
-        bert = bert[:, 3:]
-        ja_bert = ja_bert[:, 3:]
-        en_bert = en_bert[:, 3:]
-    if skip_end:
-        phones = phones[:-2]
-        tones = tones[:-2]
-        lang_ids = lang_ids[:-2]
-        bert = bert[:, :-2]
-        ja_bert = ja_bert[:, :-2]
-        en_bert = en_bert[:, :-2]
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
@@ -588,105 +192,9 @@ def infer(
        )  # , emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
-       print("Success.")
-       return audio
 
-def loadmodel(model):
-    _ = net_g.eval()
-    _ = utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
-    return "success"
-
-def generate_audio_and_srt_for_group(
-    group,
-    outputPath,
-    group_index,
-    sampling_rate,
-    speaker,
-    sdp_ratio,
-    noise_scale,
-    noise_scale_w,
-    length_scale,
-    speakerList,
-    silenceTime,
-    language,
-    mode,
-    skip_start,
-    skip_end,
-    style_text,
-    style_weight,
-):
-    audio_fin = []
-    ass_entries = []
-    start_time = 0
-    #speaker = random.choice(cara_list)
-    ass_header = """[Script Info]
-; 我没意见
-Title: Audiobook
-ScriptType: v4.00+
-WrapStyle: 0
-PlayResX: 640
-PlayResY: 360
-ScaledBorderAndShadow: yes
-[V4+ Styles]
-Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
-Style: Default,Arial,20,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,1,1,2,10,10,10,1
-[Events]
-Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
-"""
-
-    for sentence in group:
-
-        if len(sentence) > 1:
-            FakeSpeaker = sentence.split("|")[0]
-            print(FakeSpeaker)
-            SpeakersList = re.split('\n', speakerList)
-            if FakeSpeaker in list(hps.data.spk2id.keys()):
-                speaker = FakeSpeaker
-            for i in SpeakersList:
-                if FakeSpeaker == i.split("|")[1]:
-                    speaker = i.split("|")[0]
-        if sentence != '\n':
-            text = (remove_annotations(sentence.split("|")[-1]).replace(" ","")+"。").replace(",。","。")
-            if mode == 'pyopenjtalk-V2.3-Katakana' or mode == 'fugashi-V2.3-Katakana':
-                #print(f'{text}:{sdp_ratio}:{noise_scale}:{noise_scale_w}:{length_scale}:{length_scale}:{speaker}:{language}:{mode}:{skip_start}:{skip_end}')
-                audio = infer(
-                    text,
-                    sdp_ratio,
-                    noise_scale,
-                    noise_scale_w,
-                    length_scale,
-                    speaker,
-                    style_text,
-                    style_weight,
-                    language,
-                    mode,
-                    skip_start,
-                    skip_end,
-                )
-                silence_frames = int(silenceTime * 44010) if is_chinese(sentence) else int(silenceTime * 44010)
-                silence_data = np.zeros((silence_frames,), dtype=audio.dtype)
-                audio_fin.append(audio)
-                audio_fin.append(silence_data)
-                duration = len(audio) / sampling_rate
-                print(duration)
-                end_time = start_time + duration + silenceTime
-                ass_entries.append("Dialogue: 0,{},{},".format(seconds_to_ass_time(start_time), seconds_to_ass_time(end_time)) + "Default,,0,0,0,,{}".format(sentence.replace("|",":")))
-                start_time = end_time
-
-    wav_filename = os.path.join(outputPath, f'audiobook_part_{group_index}.wav')
-    ass_filename = os.path.join(outputPath, f'audiobook_part_{group_index}.ass')
-    write(wav_filename, sampling_rate, gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_fin)))
-
-    with open(ass_filename, 'w', encoding='utf-8') as f:
-        f.write(ass_header + '\n'.join(ass_entries))
-    return (hps.data.sampling_rate, gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_fin)))
-
-def generate_audio(
-    inputFile,
-    groupsize,
-    filepath,
-    silenceTime,
-    speakerList,
     text,
     sdp_ratio,
     noise_scale,
@@ -695,173 +203,120 @@ def generate_audio(
     sid,
     style_text=None,
     style_weight=0.7,
-    language = "Auto",
-    mode = 'pyopenjtalk-V2.3-Katakana',
-    sentence_mode = 'sentence',
-    skip_start=False,
-    skip_end=False,
 ):
-    if mode == 'pyopenjtalk-V2.3-Katakana' or mode == 'fugashi-V2.3-Katakana':
-        if sentence_mode == 'sentence':
-            audio = infer(
-                text,
-                sdp_ratio,
-                noise_scale,
-                noise_scale_w,
-                length_scale,
-                sid,
-                style_text,
-                style_weight,
-                language,
-                mode,
-                skip_start,
-                skip_end,
-            )
-            return (hps.data.sampling_rate,gr.processing_utils.convert_to_16_bit_wav(audio))
-        if sentence_mode == 'paragraph':
-            GROUP_SIZE = groupsize
-            directory_path = filepath if torch.cuda.is_available() else "books"
-            if os.path.exists(directory_path):
-                shutil.rmtree(directory_path)
-            os.makedirs(directory_path)
-            if inputFile:
-                text = extract_text_from_file(inputFile.name)
-            if language == 'Auto':
-                sentences = extrac(extract_and_convert(text))
-            else:
-                sentences = extrac(text)
-            for i in range(0, len(sentences), GROUP_SIZE):
-                group = sentences[i:i+GROUP_SIZE]
-                if speakerList == "":
-                    speakerList = "无"
-                result = generate_audio_and_srt_for_group(
-                    group,
-                    directory_path,
-                    i//GROUP_SIZE + 1,
-                    44100,
-                    sid,
-                    sdp_ratio,
-                    noise_scale,
-                    noise_scale_w,
-                    length_scale,
-                    speakerList,
-                    silenceTime,
-                    language,
-                    mode,
-                    skip_start,
-                    skip_end,
-                    style_text,
-                    style_weight,
-                )
-                if not torch.cuda.is_available():
-                    return result
-            return result
 
 Flaskapp = Flask(__name__)
 CORS(Flaskapp)
-@Flaskapp.route('/', methods=['GET', 'POST'])
 
 def tts():
-    if request.method == 'POST':
-        input = request.json
-        inputFile = None
-        filepath = input['filepath']
-        groupSize = input['groupSize']
-        text = input['text']
-        sdp_ratio = input['sdp_ratio']
-        noise_scale = input['noise_scale']
-        noise_scale_w = input['noise_scale_w']
-        length_scale = input['length_scale']
-        sid = input['speaker']
-        style_text = input['style_text']
-        style_weight = input['style_weight']
-        language = input['language']
-        mode = input['mode']
-        sentence_mode = input['sentence_mode']
-        skip_start = input['skip_start']
-        skip_end = input['skip_end']
-        speakerList = input['speakerList']
-        silenceTime = input['silenceTime']
-        samplerate, audio = generate_audio(
-            inputFile,
-            groupSize,
-            filepath,
-            silenceTime,
-            speakerList,
-            text,
-            sdp_ratio,
-            noise_scale,
-            noise_scale_w,
-            length_scale,
-            sid,
-            style_text,
-            style_weight,
-            language,
-            mode,
-            sentence_mode,
-            skip_start,
-            skip_end,
-        )
-        unique_filename = f"temp{uuid.uuid4()}.wav"
-        write(unique_filename, samplerate, audio)
-        with open(unique_filename ,'rb') as bit:
-            wav_bytes = bit.read()
-        os.remove(unique_filename)
-        headers = {
-            'Content-Type': 'audio/wav',
-            'Text': unique_filename .encode('utf-8')}
-        return wav_bytes, 200, headers
-    groupSize = request.args.get('groupSize', default = 50, type = int)
-    text = request.args.get('text', default = '', type = str)
-    sdp_ratio = request.args.get('sdp_ratio', default = 0.5, type = float)
-    noise_scale = request.args.get('noise_scale', default = 0.6, type = float)
-    noise_scale_w = request.args.get('noise_scale_w', default = 0.667, type = float)
-    length_scale = request.args.get('length_scale', default = 1, type = float)
-    sid = request.args.get('speaker', default = '八千代', type = str)
-    style_text = request.args.get('style_text', default = '', type = str)
-    style_weight = request.args.get('style_weight', default = 0.7, type = float)
-    language = request.args.get('language', default = 'Auto', type = str)
-    mode = request.args.get('mode', default = 'pyopenjtalk-V2.3-Katakana', type = str)
-    sentence_mode = request.args.get('sentence_mode', default = 'sentence', type = str)
-    skip_start = request.args.get('skip_start', default = False, type = bool)
-    skip_end = request.args.get('skip_end', default = False, type = bool)
-    speakerList = request.args.get('speakerList', default = '', type = str)
-    silenceTime = request.args.get('silenceTime', default = 0.1, type = float)
-    inputFile = None
-    if not sid or not text:
-        return render_template_string(f"""
-        <!DOCTYPE html>
-        <html>
-        <head>
-            <title>TTS API Documentation</title>
-        </head>
-        <body>
-            <iframe src={webBase} style="width:100%; height:100vh; border:none;"></iframe>
-        </body>
-        </html>
-        """)
-    samplerate, audio = generate_audio(
-        inputFile,
-        groupSize,
-        None,
-        silenceTime,
-        speakerList,
-        text,
-        sdp_ratio,
-        noise_scale,
-        noise_scale_w,
-        length_scale,
-        sid,
-        style_text,
-        style_weight,
-        language,
-        mode,
-        sentence_mode,
-        skip_start,
-        skip_end,
-    )
-    unique_filename = f"temp{uuid.uuid4()}.wav"
-    write(unique_filename, samplerate, audio)
    with open(unique_filename ,'rb') as bit:
        wav_bytes = bit.read()
    os.remove(unique_filename)
@@ -870,15 +325,97 @@ def tts():
        'Text': unique_filename .encode('utf-8')}
    return wav_bytes, 200, headers
 
 if __name__ == "__main__":
-    download_unidic()
-    tagger = Tagger()
     net_g = get_net_g(
         model_path=modelPaths[-1], device=device, hps=hps
     )
     speaker_ids = hps.data.spk2id
     speakers = list(speaker_ids.keys())
-
     print("推理页面已开启!")
-    Flaskapp.run(host="0.0.0.0", port=8080,debug=True)
 
 from pathlib import Path
 
 import logging
 
 import re_matching
 
 logging.getLogger("numba").setLevel(logging.WARNING)
 
 )
 
 logger = logging.getLogger(__name__)
+
 import librosa
 import numpy as np
 import torch
 
 from torch.utils.data import Dataset
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
+from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
+
+import uuid
+from flask import Flask, request, jsonify, render_template_string
+from flask_cors import CORS
 
 import gradio as gr
 
 from models import SynthesizerTrn
 from text.symbols import symbols
 import sys
 
 from scipy.io.wavfile import write
+from threading import Thread
+
 net_g = None
 
 device = (
 
     "西克菲尔特音乐学院":["晶","未知留","八千代","栞","美帆"]
 }
 
 def get_net_g(model_path: str, device: str, hps):
     net_g = SynthesizerTrn(
         len(symbols),
 
     language = torch.LongTensor(language)
     return bert, ja_bert, en_bert, phone, tone, language
 
+
 def infer(
     text,
     sdp_ratio,
 
     sid,
     style_text=None,
     style_weight=0.7,
 ):
+
+    language= 'JP' if is_japanese(text) else 'ZH'
     bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
         text,
         language,
 
         style_text=style_text,
         style_weight=style_weight,
     )
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
 
        )  # , emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
+       return (hps.data.sampling_rate,gr.processing_utils.convert_to_16_bit_wav(audio))
 
+def inferAPI(
     text,
     sdp_ratio,
     noise_scale,
 
     sid,
     style_text=None,
     style_weight=0.7,
 ):
+
+    language= 'JP' if is_japanese(text) else 'ZH'
+    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
+        text,
+        language,
+        hps,
+        device,
+        style_text=style_text,
+        style_weight=style_weight,
+    )
+    with torch.no_grad():
+        x_tst = phones.to(device).unsqueeze(0)
+        tones = tones.to(device).unsqueeze(0)
+        lang_ids = lang_ids.to(device).unsqueeze(0)
+        bert = bert.to(device).unsqueeze(0)
+        ja_bert = ja_bert.to(device).unsqueeze(0)
+        en_bert = en_bert.to(device).unsqueeze(0)
+        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+        # emo = emo.to(device).unsqueeze(0)
+        del phones
+        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+        audio = (
+            net_g.infer(
+                x_tst,
+                x_tst_lengths,
+                speakers,
+                tones,
+                lang_ids,
+                bert,
+                ja_bert,
+                en_bert,
+                sdp_ratio=sdp_ratio,
+                noise_scale=noise_scale,
+                noise_scale_w=noise_scale_w,
+                length_scale=length_scale,
+            )[0][0, 0]
+            .data.cpu()
+            .float()
+            .numpy()
+        )
+        del (
+            x_tst,
+            tones,
+            lang_ids,
+            bert,
+            x_tst_lengths,
+            speakers,
+            ja_bert,
+            en_bert,
+        )  # , emo
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        unique_filename = f"temp{uuid.uuid4()}.wav"
+        write(unique_filename, 44100, audio)
+        return unique_filename
+
+def is_japanese(string):
+    for ch in string:
+        if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
+            return True
+    return False
+
+def loadmodel(model):
+    try:
+        _ = net_g.eval()
+        _ = utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
+        return "success"
+    except:
+        return "error"
 
 Flaskapp = Flask(__name__)
 CORS(Flaskapp)
+@Flaskapp.route('/')
+
+@Flaskapp.route('/')
 
 def tts():
+    global last_text, last_model
+    speaker = request.args.get('speaker')
+    sdp_ratio = float(request.args.get('sdp_ratio', 0.2))
+    noise_scale = float(request.args.get('noise_scale', 0.6))
+    noise_scale_w = float(request.args.get('noise_scale_w', 0.8))
+    length_scale = float(request.args.get('length_scale', 1))
+    style_weight = float(request.args.get('style_weight', 0.7))
+    style_text = request.args.get('style_text', 'happy')
+    text = request.args.get('text')
+    is_chat = request.args.get('is_chat', 'false').lower() == 'true'
+    model = request.args.get('model',modelPaths[-1])
+
+    if not speaker or not text:
+        return render_template_string("""
+        <!DOCTYPE html>
+        <html>
+        <head>
+            <title>TTS API Documentation</title>
+        </head>
+        <body>
+            <iframe src="http://127.0.0.1:7860" style="width:100%; height:100vh; border:none;"></iframe>
+        </body>
+        </html>
+        """)
+
+    if model != last_model:
+        unique_filename = loadmodel(model)
+        last_model = model
+    if is_chat and text == last_text:
+        # Generate 1 second of silence and return
+        unique_filename = 'blank.wav'
+        silence = np.zeros(44100, dtype=np.int16)
+        write(unique_filename , 44100, silence)
+    else:
+        last_text = text
+        unique_filename = inferAPI(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale,sid = speaker, style_text=style_text, style_weight=style_weight)
     with open(unique_filename ,'rb') as bit:
         wav_bytes = bit.read()
     os.remove(unique_filename)
 
         'Text': unique_filename .encode('utf-8')}
     return wav_bytes, 200, headers
 
+def gradio_interface():
+    return app.launch(share=True)
 
 if __name__ == "__main__":
+    languages = [ "Auto", "ZH", "JP"]
+    modelPaths = []
+    for dirpath, dirnames, filenames in os.walk('Data/Chinese/models/'):
+        for filename in filenames:
+            modelPaths.append(os.path.join(dirpath, filename))
+    hps = utils.get_hparams_from_file('Data/Chinese/config.json')
     net_g = get_net_g(
         model_path=modelPaths[-1], device=device, hps=hps
     )
     speaker_ids = hps.data.spk2id
     speakers = list(speaker_ids.keys())
+    last_text = ""
+    last_model = modelPaths[-1]
+    with gr.Blocks() as app:
+        for band in BandList:
+            with gr.TabItem(band):
+                for name in BandList[band]:
+                    with gr.TabItem(name):
+                        with gr.Row():
+                            with gr.Column():
+                                with gr.Row():
+                                    gr.Markdown(
+                                        '<div align="center">'
+                                        f'<img style="width:auto;height:400px;" src="https://mahiruoshi-bangdream-bert-vits2.hf.space/file/image/{name}.png">'
+                                        '</div>'
+                                    )
+                                length_scale = gr.Slider(
+                                    minimum=0.1, maximum=2, value=1, step=0.01, label="语速调节"
+                                )
+                                with gr.Accordion(label="参数设定", open=False):
+                                    sdp_ratio = gr.Slider(
+                                        minimum=0, maximum=1, value=0.5, step=0.01, label="SDP/DP混合比"
+                                    )
+                                    noise_scale = gr.Slider(
+                                        minimum=0.1, maximum=2, value=0.6, step=0.01, label="感情调节"
+                                    )
+                                    noise_scale_w = gr.Slider(
+                                        minimum=0.1, maximum=2, value=0.667, step=0.01, label="音素长度"
+                                    )
+                                    speaker = gr.Dropdown(
+                                        choices=speakers, value=name, label="说话人"
+                                    )
+                                with gr.Accordion(label="切换模型", open=False):
+                                    modelstrs = gr.Dropdown(label = "模型", choices = modelPaths, value = modelPaths[0], type = "value")
+                                    btnMod = gr.Button("载入模型")
+                                    statusa = gr.TextArea()
+                                    btnMod.click(loadmodel, inputs=[modelstrs], outputs = [statusa])
+                            with gr.Column():
+                                text = gr.TextArea(
+                                    label="输入纯日语或者中文",
+                                    placeholder="输入纯日语或者中文",
+                                    value="为什么要演奏春日影!",
+                                )
+                                style_text = gr.Textbox(label="辅助文本")
+                                style_weight = gr.Slider(
+                                    minimum=0,
+                                    maximum=1,
+                                    value=0.7,
+                                    step=0.1,
+                                    label="Weight",
+                                    info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
+                                )
+                                btn = gr.Button("点击生成", variant="primary")
+                                audio_output = gr.Audio(label="Output Audio")
+                                '''
+                                btntran = gr.Button("快速中翻日")
+                                translateResult = gr.TextArea("从这复制翻译后的文本")
+                                btntran.click(translate, inputs=[text], outputs = [translateResult])
+                                '''
+                                btn.click(
+                                    infer,
+                                    inputs=[
+                                        text,
+                                        sdp_ratio,
+                                        noise_scale,
+                                        noise_scale_w,
+                                        length_scale,
+                                        speaker,
+                                        style_text,
+                                        style_weight,
+                                    ],
+                                    outputs=[audio_output],
+                                )
+
+    api_thread = Thread(target=Flaskapp.run, args=("0.0.0.0", 8000))
+    gradio_thread = Thread(target=gradio_interface)
+    gradio_thread.start()
     print("推理页面已开启!")
+    api_thread.start()
+    print("api页面已开启!运行在8000端口")