wetdog committed on
Commit 7602717
1 Parent(s): 396b4b5

add gradio app

Files changed (1)
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
+ ## VCTK
+ import torch
+
+ import commons
+ import utils
+ from models import SynthesizerTrn
+ from text.symbols import symbols
+ from text import text_to_sequence
+
+ import gradio as gr
+
+ print("Running Gradio", gr.__version__)
+
+ model_path = "vits2_pytorch/G_390000.pth"
+ config_path = "vits2_pytorch/vits2_vctk_cat_inference.json"
+
+ hps = utils.get_hparams_from_file(config_path)
+
+ # VITS2 configs use a mel-spectrogram posterior encoder (80 channels);
+ # VITS1 uses a linear-spectrogram one.
+ if (
+     "use_mel_posterior_encoder" in hps.model.keys()
+     and hps.model.use_mel_posterior_encoder
+ ):
+     print("Using mel posterior encoder for VITS2")
+     posterior_channels = 80  # vits2
+     hps.data.use_mel_posterior_encoder = True
+ else:
+     print("Using lin posterior encoder for VITS1")
+     posterior_channels = hps.data.filter_length // 2 + 1
+     hps.data.use_mel_posterior_encoder = False
+
+ net_g = SynthesizerTrn(
+     len(symbols),
+     posterior_channels,
+     hps.train.segment_size // hps.data.hop_length,
+     n_speakers=hps.data.n_speakers,
+     **hps.model
+ )
+ _ = net_g.eval()
+
+ _ = utils.load_checkpoint(model_path, net_g, None)
+
+
+ def get_text(text, hps):
+     text_norm = text_to_sequence(text, hps.data.text_cleaners)
+     # text_norm = cleaned_text_to_sequence(text)  # if the model was trained on pre-cleaned text
+
+     if hps.data.add_blank:
+         text_norm = commons.intersperse(text_norm, 0)
+     text_norm = torch.LongTensor(text_norm)
+     return text_norm
+
+ def tts(text: str, speaker_id: int, speed: float, noise_scale: float = 0.667, noise_scale_w: float = 0.8):
+     stn_tst = get_text(text, hps)
+     with torch.no_grad():
+         x_tst = stn_tst.unsqueeze(0)
+         x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
+         sid = torch.LongTensor([speaker_id])
+         waveform = (
+             net_g.infer(
+                 x_tst,
+                 x_tst_lengths,
+                 sid=sid,
+                 noise_scale=noise_scale,
+                 noise_scale_w=noise_scale_w,
+                 length_scale=1 / speed,  # length_scale is the inverse of speed
+             )[0][0, 0]
+             .data.cpu()
+             .float()
+             .numpy()
+         )
+
+     return gr.make_waveform((22050, waveform))  # 22050 Hz model sampling rate
+
+ ## GUI space
+
+ title = """
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+   <div
+     style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
+   >
+     <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
+       VITS2 TTS Catalan Demo
+     </h1>
+   </div>
+ </div>
+ """
+
+ description = """
+ VITS2 is an end-to-end speech synthesis model that predicts a speech waveform conditioned on an input text sequence. VITS2 improves
+ training and inference efficiency and naturalness by introducing adversarial learning into the duration predictor. A transformer
+ block was added to the normalizing flows to capture long-term dependencies when transforming the distribution,
+ and synthesis quality was improved by incorporating Gaussian noise into the alignment search.
+
+ This model is being trained on the openslr69 and festcat datasets.
+ """
+
+ article = "Model by Jungil Kong et al., SK Telecom. Demo by BSC."
+
+ vits2_inference = gr.Interface(
+     fn=tts,
+     inputs=[
+         gr.Textbox(
+             value="m'ha costat desenvolupar molt una veu, i ara que la tinc no estaré en silenci.",
+             max_lines=1,
+             label="Input text",
+         ),
+         gr.Slider(
+             1,
+             47,
+             value=10,
+             step=1,
+             label="Speaker id",
+             info="This model is trained on 47 speakers. You can prompt the model using one of these speaker ids.",
+         ),
+         gr.Slider(
+             0.5,
+             1.5,
+             value=1,
+             step=0.1,
+             label="Speed",
+         ),
+         gr.Slider(
+             0.2,
+             2.0,
+             value=0.667,
+             step=0.01,
+             label="Noise scale",
+         ),
+         gr.Slider(
+             0.2,
+             2.0,
+             value=0.8,
+             step=0.01,
+             label="Noise scale w",
+         ),
+     ],
+     outputs=gr.Audio(),
+ )
+
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown(title)
+     gr.Markdown(description)
+     gr.TabbedInterface([vits2_inference], ["Multispeaker"])
+     gr.Markdown(article)
+
+ demo.queue(max_size=10)
+ demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)
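
For reference, a minimal offline sketch of the same inference path, assuming the hps, net_g, and get_text defined in app.py are in scope (for example, appended at the bottom of the file); the input text, speaker id, and output filename are illustrative placeholders:

    # Offline synthesis: reuse the loaded model to write a WAV file
    # instead of rendering a Gradio waveform.
    import numpy as np
    import torch
    from scipy.io.wavfile import write

    stn_tst = get_text("Bon dia!", hps)  # hypothetical input text
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
        audio = net_g.infer(
            x_tst,
            x_tst_lengths,
            sid=torch.LongTensor([10]),  # placeholder speaker id
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0,
        )[0][0, 0].data.cpu().float().numpy()

    # Scale float audio in [-1, 1] to 16-bit PCM; 22050 Hz matches the model.
    write("sample.wav", 22050, (audio * 32767).astype(np.int16))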