ulysses115 committed on
Commit 2d8ad0f
1 Parent(s): 1c96e2d

Update app.py

Files changed (1)
  1. app.py +191 -1
app.py CHANGED
@@ -1,3 +1,193 @@
+# import gradio as gr
+
+# gr.Interface.load("models/ulysses115/pmvoice").launch()
+
+import argparse
+import json
+import os
+import re
+import tempfile
+
+import librosa
+import numpy as np
+import torch
+from torch import no_grad, LongTensor
+import commons
+import utils
 import gradio as gr
+import gradio.utils as gr_utils
+import gradio.processing_utils as gr_processing_utils
+from models import SynthesizerTrn
+from text import text_to_sequence, _clean_text
+from mel_processing import spectrogram_torch
+
+limitation = False  # os.getenv("SYSTEM") == "spaces"  # limit text and audio length in huggingface spaces
+
+
+def audio_postprocess(self, y):
+    if y is None:
+        return None
+
+    if gr_utils.validate_url(y):
+        file = gr_processing_utils.download_to_file(y, dir=self.temp_dir)
+    elif isinstance(y, tuple):
+        sample_rate, data = y
+        file = tempfile.NamedTemporaryFile(
+            suffix=".wav", dir=self.temp_dir, delete=False
+        )
+        gr_processing_utils.audio_to_file(sample_rate, data, file.name)
+    else:
+        file = gr_processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)
+
+    return gr_processing_utils.encode_url_or_file_to_base64(file.name)
+
+
+gr.Audio.postprocess = audio_postprocess
+
+
+def get_text(text, hps, is_symbol):
+    text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
+    if hps.data.add_blank:
+        text_norm = commons.intersperse(text_norm, 0)
+    text_norm = LongTensor(text_norm)
+    return text_norm
+
+
+def create_tts_fn(model, hps, speaker_ids):
+    def tts_fn(text, speaker, speed, is_symbol):
+        if limitation:
+            text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
+            max_len = 150
+            if is_symbol:
+                max_len *= 3
+            if text_len > max_len:
+                return "Error: Text is too long", None
+
+        speaker_id = speaker_ids[speaker]
+        stn_tst = get_text(text, hps, is_symbol)
+        with no_grad():
+            x_tst = stn_tst.unsqueeze(0).to(device)
+            x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
+            sid = LongTensor([speaker_id]).to(device)
+            audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
+                                length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
+        del stn_tst, x_tst, x_tst_lengths, sid
+        return "Success", (hps.data.sampling_rate, audio)
+
+    return tts_fn
+
+
+def create_to_symbol_fn(hps):
+    def to_symbol_fn(is_symbol_input, input_text, temp_text):
+        return (_clean_text(input_text, hps.data.text_cleaners), input_text) if is_symbol_input \
+            else (temp_text, temp_text)
+
+    return to_symbol_fn
+
+
+download_audio_js = """
+() =>{{
+    let root = document.querySelector("body > gradio-app");
+    if (root.shadowRoot != null)
+        root = root.shadowRoot;
+    let audio = root.querySelector("#{audio_id}").querySelector("audio");
+    if (audio == undefined)
+        return;
+    audio = audio.src;
+    let oA = document.createElement("a");
+    oA.download = Math.floor(Math.random()*100000000)+'.wav';
+    oA.href = audio;
+    document.body.appendChild(oA);
+    oA.click();
+    oA.remove();
+}}
+"""
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default='cpu')
+    parser.add_argument("--share", action="store_true", default=True, help="share gradio app")
+    args = parser.parse_args()
+
+    device = torch.device(args.device)
+    models_tts = []
+    with open("saved_model/info.json", "r", encoding="utf-8") as f:
+        models_info = json.load(f)
+    for i, info in models_info.items():
+        name = info["title"]
+        author = info["author"]
+        lang = info["lang"]
+        example = info["example"]
+        config_path = f"saved_model/{i}/config.json"
+        model_path = f"saved_model/{i}/model.pth"
+        cover = info["cover"]
+        cover_path = f"saved_model/{i}/{cover}" if cover else None
+        hps = utils.get_hparams_from_file(config_path)
+        model = SynthesizerTrn(
+            len(hps.symbols),
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model)
+        utils.load_checkpoint(model_path, model, None)
+        model.eval().to(device)
+        speaker_ids = [sid for sid, name in enumerate(hps.speakers) if name != "None"]
+        speakers = [name for sid, name in enumerate(hps.speakers) if name != "None"]
+
+        t = info["type"]
+        if t == "vits":
+            models_tts.append((name, author, cover_path, speakers, lang, example,
+                               hps.symbols, create_tts_fn(model, hps, speaker_ids),
+                               create_to_symbol_fn(hps)))
+
+    app = gr.Blocks()
+
+    with app:
+        for i, (name, author, cover_path, speakers, lang, example, symbols, tts_fn,
+                to_symbol_fn) in enumerate(models_tts):
+            with gr.TabItem(f"model{i}"):
+                with gr.Column():
+                    tts_input1 = gr.TextArea(label="Text (150 words limitation)", value=example,
+                                             elem_id=f"tts-input{i}")
+                    tts_input2 = gr.Dropdown(label="Speaker", choices=speakers,
+                                             type="index", value=speakers[0])
+                    tts_input3 = gr.Slider(label="Speed", value=1, minimum=0.5, maximum=2, step=0.1)
+                    with gr.Accordion(label="Advanced Options", open=False):
+                        temp_text_var = gr.Variable()
+                        symbol_input = gr.Checkbox(value=False, label="Symbol input")
+                        symbol_list = gr.Dataset(label="Symbol list", components=[tts_input1],
+                                                 samples=[[x] for x in symbols],
+                                                 elem_id=f"symbol-list{i}")
+                        symbol_list_json = gr.Json(value=symbols, visible=False)
+                    tts_submit = gr.Button("Generate", variant="primary")
+                    tts_output1 = gr.Textbox(label="Output Message")
+                    tts_output2 = gr.Audio(label="Output Audio", elem_id=f"tts-audio{i}")
+                    download = gr.Button("Download Audio")
+                    download.click(None, [], [], _js=download_audio_js.format(audio_id=f"tts-audio{i}"))
 
-gr.Interface.load("models/ulysses115/pmvoice").launch()
+                    tts_submit.click(tts_fn, [tts_input1, tts_input2, tts_input3, symbol_input],
+                                     [tts_output1, tts_output2])
+                    symbol_input.change(to_symbol_fn,
+                                        [symbol_input, tts_input1, temp_text_var],
+                                        [tts_input1, temp_text_var])
+                    symbol_list.click(None, [symbol_list, symbol_list_json], [],
+                                      _js=f"""
+                    (i,symbols) => {{
+                        let root = document.querySelector("body > gradio-app");
+                        if (root.shadowRoot != null)
+                            root = root.shadowRoot;
+                        let text_input = root.querySelector("#tts-input{i}").querySelector("textarea");
+                        let startPos = text_input.selectionStart;
+                        let endPos = text_input.selectionEnd;
+                        let oldTxt = text_input.value;
+                        let result = oldTxt.substring(0, startPos) + symbols[i] + oldTxt.substring(endPos);
+                        text_input.value = result;
+                        let x = window.scrollX, y = window.scrollY;
+                        text_input.focus();
+                        text_input.selectionStart = startPos + symbols[i].length;
+                        text_input.selectionEnd = startPos + symbols[i].length;
+                        text_input.blur();
+                        window.scrollTo(x, y);
+                        return [];
+                    }}""")
+    app.queue(concurrency_count=1).launch(show_api=True, share=args.share)
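
Note on the expected model layout (not part of the diff above): the __main__ block reads saved_model/info.json and then, for each entry id, saved_model/<id>/config.json and saved_model/<id>/model.pth. A minimal sketch of a compatible info.json, written from Python, follows; only the keys (title, author, lang, example, cover, type) come from the code above, while the id "0" and every value are illustrative placeholders, and the checkpoint files still have to be supplied separately.

# Sketch only: writes a saved_model/info.json in the shape app.py reads.
# The id "0" and all metadata values below are assumptions, not part of this commit.
import json
import os

models_info = {
    "0": {                              # entry id, i.e. the saved_model/0/ directory
        "title": "Example VITS voice",  # read into `name`
        "author": "unknown",            # read into `author`
        "lang": "Japanese",             # read into `lang`
        "example": "こんにちは。",        # default text placed in the TextArea
        "cover": "",                    # cover file name inside saved_model/0/, "" means no cover
        "type": "vits"                  # only "vits" entries are appended to models_tts
    }
}

os.makedirs("saved_model/0", exist_ok=True)
with open("saved_model/info.json", "w", encoding="utf-8") as f:
    json.dump(models_info, f, ensure_ascii=False, indent=2)

With real config.json and model.pth files placed in each listed directory, python app.py starts the Blocks interface; --device selects the torch device (default cpu), and --share defaults to True in this commit.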