kevinwang676 committed on
Commit
1f224f2
1 Parent(s): 7a44c08

Update app.py

Files changed (1)
  1. app.py +46 -124
app.py CHANGED
@@ -1,37 +1,14 @@
-import librosa
-import matplotlib.pyplot as plt
-
+import argparse
+import gradio as gr
+from gradio import components
 import os
-import json
-import math
-
-import requests
 import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.utils.data import DataLoader
-
 import commons
 import utils
-from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
 from models import SynthesizerTrn
 from text.symbols import symbols
 from text import text_to_sequence
-import langdetect
-
 from scipy.io.wavfile import write
-import re
-from scipy import signal
-import gradio as gr
-
-'''
-from phonemizer.backend.espeak.wrapper import EspeakWrapper
-_ESPEAK_LIBRARY = 'C:\Program Files\eSpeak NG\libespeak-ng.dll'
-EspeakWrapper.set_library(_ESPEAK_LIBRARY)
-'''
-# check device
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-
 
 def get_text(text, hps):
     text_norm = text_to_sequence(text, hps.data.text_cleaners)
@@ -40,109 +17,54 @@ def get_text(text, hps):
     text_norm = torch.LongTensor(text_norm)
     return text_norm
 
 
-def langdetector(text):  # from PolyLangVITS
-    try:
-        lang = langdetect.detect(text).lower()
-        if lang == 'ko':
-            return f'[KO]{text}[KO]'
-        elif lang == 'ja':
-            return f'[JA]{text}[JA]'
-        elif lang == 'en':
-            return f'[EN]{text}[EN]'
-        elif lang == 'zh-cn':
-            return f'[ZH]{text}[ZH]'
-        else:
-            return text
-    except Exception as e:
-        return text
-
-
-def vcss(inputstr):  # single
-    fltstr = re.sub(r"[\[\]\(\)\{\}]", "", inputstr)
-    #fltstr = langdetector(fltstr)  #- optional for cjke/cjks type cleaners
-    stn_tst = get_text(fltstr, hps)
-
-    speed = 1
-    output_dir = 'output'
-    sid = 0
+def tts(model_path, config_path, text):
+    model_path = "./logs/G_23300.pth"
+    config_path = "./configs/config.json"
+    hps = utils.get_hparams_from_file(config_path)
+
+    if "use_mel_posterior_encoder" in hps.model.keys() and hps.model.use_mel_posterior_encoder == True:
+        posterior_channels = 80
+        hps.data.use_mel_posterior_encoder = True
+    else:
+        posterior_channels = hps.data.filter_length // 2 + 1
+        hps.data.use_mel_posterior_encoder = False
+
+    net_g = SynthesizerTrn(
+        len(symbols),
+        posterior_channels,
+        hps.train.segment_size // hps.data.hop_length,
+        **hps.model).cuda()
+    _ = net_g.eval()
+    _ = utils.load_checkpoint(model_path, net_g, None)
+
+    stn_tst = get_text(text, hps)
+    x_tst = stn_tst.cuda().unsqueeze(0)
+    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
+
     with torch.no_grad():
-        x_tst = stn_tst.to(device).unsqueeze(0)
-        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
-        audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1 / speed)[0][
-            0, 0].data.cpu().float().numpy()
-    write("output.wav", hps.data.sampling_rate, audio)
-
-    return "output.wav"
-
-"""
-def vcms(inputstr, sid):
-    fltstr = re.sub(r"[\[\]\(\)\{\}]", "", inputstr)
-    fltstr = langdetector(fltstr)
-    stn_tst = get_text(fltstr, hps)
-
-    speed = 1
-    output_dir = 'output'
-    with torch.no_grad():
-        x_tst = stn_tst.to(device).unsqueeze(0)
-        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
-        sid = torch.LongTensor([sid]).to(device)
-        audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1 / speed)[0][
-            0, 0].data.cpu().float().numpy()
-    write(f'./{output_dir}/output_{sid}.wav', hps.data.sampling_rate, audio)
-    print(f'./{output_dir}/output_{sid}.wav Generated!')
-"""
-
-
-hps = utils.get_hparams_from_file("./configs/config.json")
-
-if "use_mel_posterior_encoder" in hps.model.keys() and hps.model.use_mel_posterior_encoder == True:
-    print("Using mel posterior encoder for VITS2")
-    posterior_channels = 80  # vits2
-    hps.data.use_mel_posterior_encoder = True
-else:
-    print("Using lin posterior encoder for VITS1")
-    posterior_channels = hps.data.filter_length // 2 + 1
-    hps.data.use_mel_posterior_encoder = False
-
-net_g = SynthesizerTrn(
-    len(symbols),
-    posterior_channels,
-    hps.train.segment_size // hps.data.hop_length,
-    # n_speakers=hps.data.n_speakers,  #- for multi speaker
-    **hps.model).to(device)
-_ = net_g.eval()
-
-_ = utils.load_checkpoint("./logs/G_6100.pth", net_g, None)
-
-# - text input
-
-def infer(text):
-
-    return vcss(text)
-
-app = gr.Blocks()
-
-with app:
-    gr.Markdown("# <center>🥳🎶🎡 - VITS2真实拟声</center>")
-    gr.Markdown("## <center>🌟 - 稻妻神里流太刀术皆传 神里绫华参上 </center>")
-    gr.Markdown("### <center>🌊 - 更多精彩应用,敬请关注[滔滔AI](http://www.talktalkai.com);滔滔AI,为爱滔滔!💕</center>")
-
-    with gr.Row():
-        with gr.Column():
-            inp1 = gr.Textbox(label="请在这里填写您想合成的文本", placeholder="想说却还没说的 还很多...", lines=3)
-            btn1 = gr.Button("3.一键推理", variant="primary")
-        with gr.Column():
-            out1 = gr.Audio(type="filepath", label="为您合成的神里绫华语音")
-
-    btn1.click(infer, inp1, out1)
-
-    gr.Markdown("### <center>注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。</center>")
-    gr.HTML('''
-    <div class="footer">
-        <p>🌊🏞️🎶 - 江水东流急,滔滔无尽声。 明·顾璘
-        </p>
-    </div>
-    ''')
-app.launch(show_error=True)
+        audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0, 0].data.cpu().float().numpy()
+
+    output_wav_path = "output.wav"
+    write(output_wav_path, hps.data.sampling_rate, audio)
+
+    return output_wav_path
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_path', type=str, default="./logs/G_23300.pth", help='Path to the model file.')
+    parser.add_argument('--config_path', type=str, default="./configs/config.json", help='Path to the config file.')
+    args = parser.parse_args()
+
+    model_files = [f for f in os.listdir('./logs/') if f.endswith('.pth')]
+    model_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]), reverse=True)
+    config_files = [f for f in os.listdir('./configs/') if f.endswith('.json')]
+
+    default_model_file = args.model_path if args.model_path else (model_files[0] if model_files else None)
+    default_config_file = args.config_path if args.config_path else 'config.json'
+
+    gr.Interface(
+        fn=tts,
+        inputs=components.Textbox(label="Text Input"),
+        outputs=components.Audio(type='filepath', label="Generated Speech"),
+        live=False
+    ).launch(show_error=True)
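
Editor's note on the new code, not part of the commit: `gr.Interface(fn=tts, inputs=components.Textbox(...))` supplies only the text box value, but `tts` declares three parameters, so a submission would raise a TypeError before the hardcoded `model_path`/`config_path` assignments inside the function even run; the `model_files`/`default_model_file` discovery in `__main__` is likewise computed but never used. A minimal sketch of one way to wire this up, assuming the hardcoded path assignments at the top of `tts` are deleted so its parameters are honored:

    import functools

    import gradio as gr
    from gradio import components

    # Sketch only: assumes tts(model_path, config_path, text) from app.py
    # no longer overwrites its first two arguments internally.
    bound_tts = functools.partial(tts, args.model_path, args.config_path)

    gr.Interface(
        fn=bound_tts,  # Gradio now calls bound_tts(text) with the Textbox value
        inputs=components.Textbox(label="Text Input"),
        outputs=components.Audio(type='filepath', label="Generated Speech"),
        live=False
    ).launch(show_error=True)

Passing `default_model_file` instead of `args.model_path` would additionally make the newest-checkpoint discovery take effect. Separately, the removed version selected a device with `device = "cuda:0" if torch.cuda.is_available() else "cpu"`, while the new `tts` calls `.cuda()` unconditionally; restoring that guard and using `.to(device)` would keep the app runnable on CPU-only hosts.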