FrankZxShen committed
Commit f674379
1 Parent(s): 49c1e2f

Update app.py

Files changed (1): app.py (+158, −75)
app.py CHANGED
@@ -20,36 +20,40 @@ language_marks = {
     "English": "[EN]",
     "Mix": "",
 }
-lang = ['日本語', '简体中文', 'English', 'Mix']
 
+limitation = os.getenv("SYSTEM") == "spaces"  # limit text and audio length in huggingface spaces
 
 def get_text(text, hps, is_symbol):
-    text_norm = text_to_sequence(
-        text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
+    text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
     if hps.data.add_blank:
         text_norm = commons.intersperse(text_norm, 0)
     text_norm = LongTensor(text_norm)
     return text_norm
 
-
 def create_tts_fn(model, hps, speaker_ids):
-    def tts_fn(text, speaker, language, speed):
+    def tts_fn(text, speaker, language, ns, nsw, speed, is_symbol):
+        if limitation:
+            text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
+            max_len = 150
+            if is_symbol:
+                max_len *= 3
+            if text_len > max_len:
+                return "Error: Text is too long", None
         if language is not None:
             text = language_marks[language] + text + language_marks[language]
         speaker_id = speaker_ids[speaker]
-        stn_tst = get_text(text, hps, False)
+        stn_tst = get_text(text, hps, is_symbol)
         with no_grad():
             x_tst = stn_tst.unsqueeze(0).to(device)
             x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
             sid = LongTensor([speaker_id]).to(device)
-            audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
+            audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=ns, noise_scale_w=nsw,
                                 length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
         del stn_tst, x_tst, x_tst_lengths, sid
         return "Success", (hps.data.sampling_rate, audio)
 
     return tts_fn
 
-
 def create_vc_fn(model, hps, speaker_ids):
     def vc_fn(original_speaker, target_speaker, record_audio, upload_audio):
         input_audio = record_audio if record_audio is not None else upload_audio
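Note: the main changes in this hunk are the new length gate in tts_fn and the ns/nsw parameters that replace the hard-coded noise_scale=.667 / noise_scale_w=0.8. A minimal standalone sketch of the gate — check_text_length is a hypothetical helper name, not part of the commit; the regex, the 150-character budget, and the tripled budget for symbol input come straight from the diff. (The hunk itself does not add imports, so os and re must already be imported at the top of app.py.)

    import os
    import re

    limitation = os.getenv("SYSTEM") == "spaces"  # True only on HF Spaces

    def check_text_length(text, is_symbol):
        # Language marks such as [JA]...[JA] are stripped before counting,
        # so wrapping the text in marks does not eat into the budget.
        text_len = len(re.sub(r"\[([A-Z]{2})\]", "", text))
        max_len = 150 * (3 if is_symbol else 1)
        if limitation and text_len > max_len:
            return "Error: Text is too long"
        return None  # within budget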
@@ -63,8 +67,7 @@ def create_vc_fn(model, hps, speaker_ids):
         if len(audio.shape) > 1:
             audio = librosa.to_mono(audio.transpose(1, 0))
         if sampling_rate != hps.data.sampling_rate:
-            audio = librosa.resample(
-                audio, orig_sr=sampling_rate, target_sr=hps.data.sampling_rate)
+            audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=hps.data.sampling_rate)
         with no_grad():
             y = torch.FloatTensor(audio)
             y = y / max(-y.min(), y.max()) / 0.99
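This hunk only joins the librosa.resample call onto one line; the keyword arguments are the part worth keeping, since newer librosa releases reject positional sample-rate arguments (deprecated in 0.9, removed around 0.10 — treat the exact version boundary as an assumption). A quick self-contained check of the call as used here:

    import numpy as np
    import librosa

    sr_in, sr_out = 22050, 16000
    y = np.zeros(sr_in, dtype=np.float32)  # one second of silence at 22.05 kHz
    y_rs = librosa.resample(y, orig_sr=sr_in, target_sr=sr_out)
    assert y_rs.shape[0] == sr_out  # one second at the target rate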
@@ -83,32 +86,70 @@ def create_vc_fn(model, hps, speaker_ids):
 
     return vc_fn
 
+def get_text(text, hps, is_symbol):
+    text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
+    if hps.data.add_blank:
+        text_norm = commons.intersperse(text_norm, 0)
+    text_norm = LongTensor(text_norm)
+    return text_norm
+
+
+def create_to_symbol_fn(hps):
+    def to_symbol_fn(is_symbol_input, input_text, temp_text):
+        return (_clean_text(input_text, hps.data.text_cleaners), input_text) if is_symbol_input \
+            else (temp_text, temp_text)
+
+    return to_symbol_fn
 
+models_info = [
+    {
+        "languages": ['日本語', '简体中文', 'English', 'Mix'],
+        "description": """
+
+        这个模型包含公主连结Re:Dive的161名角色,能合成中日英三语。
+
+
+        """,
+        "model_path": "./OUTPUT_MODEL/G_9700.pth",
+        "config_path": "./configs/modified_finetune_speaker.json",
+        "examples": [['大切な人の誕生日を祝えるって、すごく幸せなことなんですよね。', '佩可莉姆', '日本語', 1, False],
+                     ['その…この制服、どうですか?', '栞', '日本語', 1, False],
+                     ['你们全都给我让开!敢挡路的家伙,我统统斩了!', '矛依未', '简体中文', 1, False],
+                     ['Can you tell me how much the shirt is?', '美咲', 'English', 1, False],
+                     ['[EN]Excuse me?[EN][JA]お帰りなさい,お兄様![JA]', '咲恋(夏日)', 'Mix', 1, False]],
+    }
+]
+
+models_tts = []
+models_vc = []
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_dir", default="./models/G_9700.pth",
-                        help="directory to your fine-tuned model")
-    parser.add_argument("--config_dir", default="./configs/modified_finetune_speaker.json",
-                        help="directory to your model config file")
-    parser.add_argument("--share", action="store_true", default=False,
-                        help="make link public (used in colab)")
-
+    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
     args = parser.parse_args()
-    hps = utils.get_hparams_from_file(args.config_dir)
-
-    net_g = SynthesizerTrn(
-        len(hps.symbols),
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=hps.data.n_speakers,
-        **hps.model).to(device)
-    _ = net_g.eval()
-
-    _ = utils.load_checkpoint(args.model_dir, net_g, None)
-    speaker_ids = hps.speakers
-    speakers = list(hps.speakers.keys())
-    tts_fn = create_tts_fn(net_g, hps, speaker_ids)
-    vc_fn = create_vc_fn(net_g, hps, speaker_ids)
+    for info in models_info:
+        lang = info['languages']
+        examples = info['examples']
+        config_path = info['config_path']
+        model_path = info['model_path']
+        description = info['description']
+        hps = utils.get_hparams_from_file(config_path)
+
+        net_g = SynthesizerTrn(
+            len(hps.symbols),
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model).to(device)
+        _ = net_g.eval()
+
+        _ = utils.load_checkpoint(model_path, net_g, None)
+        speaker_ids = hps.speakers
+        speakers = list(hps.speakers.keys())
+        models_tts.append((description, speakers, lang, examples,
+                           hps.symbols, create_tts_fn(net_g, hps, speaker_ids),
+                           create_to_symbol_fn(hps)))
+        models_vc.append((description, speakers, create_vc_fn(net_g, hps, speaker_ids)))
+
     app = gr.Blocks()
     with app:
         gr.Markdown(
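This hunk drops the --model_dir/--config_dir flags and replaces the single hard-coded checkpoint with a models_info registry plus a load loop, so additional fine-tuned models become data rather than code. The Chinese description string says, roughly: "This model covers 161 characters from Princess Connect! Re:Dive and can synthesize Chinese, Japanese, and English." Note also that the hunk re-defines get_text identically to the version near the top of the file; the duplicate is harmless but looks unintentional. A hypothetical second registry entry, with placeholder paths that do not exist in this repo:

    models_info.append({
        "languages": ['日本語', '简体中文', 'English', 'Mix'],
        "description": "A second fine-tuned checkpoint (placeholder).",
        "model_path": "./OUTPUT_MODEL/G_other.pth",
        "config_path": "./configs/other_finetune_speaker.json",
        "examples": [],
    })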
@@ -119,46 +160,88 @@ if __name__ == "__main__":
             "[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm-dark.svg)](https://huggingface.co/spaces/FrankZxShen/vits-fast-finetuning-pcr?duplicate=true)\n\n"
             "[![Finetune your own model](https://badgen.net/badge/icon/github?icon=github&label=Finetune%20your%20own%20model)](https://github.com/Plachtaa/VITS-fast-fine-tuning)"
         )
-        with gr.Tab("Text-to-Speech"):
-            with gr.Row():
-                with gr.Column():
-                    textbox = gr.TextArea(label="Text",
-                                          placeholder="Type your sentence here",
-                                          value="新たなキャラを解放できるようになったようですね。", elem_id=f"tts-input")
-                    # select character
-                    char_dropdown = gr.Dropdown(
-                        choices=speakers, value=speakers[0], label='character')
-                    language_dropdown = gr.Dropdown(
-                        choices=lang, value=lang[0], label='language')
-                    duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1,
-                                                label='速度 Speed')
-                with gr.Column():
-                    text_output = gr.Textbox(label="Message")
-                    audio_output = gr.Audio(
-                        label="Output Audio", elem_id="tts-audio")
-                    btn = gr.Button("Generate!")
-                    btn.click(tts_fn,
-                              inputs=[textbox, char_dropdown,
-                                      language_dropdown, duration_slider, ],
-                              outputs=[text_output, audio_output])
-        with gr.Tab("Voice Conversion"):
-            gr.Markdown("""
-                录制或上传声音,并选择要转换的音色。
-            """)
-            with gr.Column():
-                record_audio = gr.Audio(
-                    label="record your voice", source="microphone")
-                upload_audio = gr.Audio(
-                    label="or upload audio here", source="upload")
-                source_speaker = gr.Dropdown(
-                    choices=speakers, value=speakers[0], label="source speaker")
-                target_speaker = gr.Dropdown(
-                    choices=speakers, value=speakers[0], label="target speaker")
-            with gr.Column():
-                message_box = gr.Textbox(label="Message")
-                converted_audio = gr.Audio(label='converted audio')
-            btn = gr.Button("Convert!")
-            btn.click(vc_fn, inputs=[source_speaker, target_speaker, record_audio, upload_audio],
-                      outputs=[message_box, converted_audio])
-    webbrowser.open("http://127.0.0.1:7899")
-    app.queue(concurrency_count=1, api_open=False).launch(share=args.share)
+        gr.Markdown("# TTS&Voice Conversion for Princess Connect! Re:Dive\n\n"
+                    )
+        with gr.Tabs():
+            with gr.Tab("TTS"):
+                for i, (description, speakers, lang, example, symbols, tts_fn, to_symbol_fn) in enumerate(
+                        models_tts):
+                    gr.Markdown(description)
+                    with gr.Row():
+                        with gr.Column():
+                            textbox = gr.TextArea(label="Text",
+                                                  placeholder="Type your sentence here (Maximum 150 words)",
+                                                  value="新たなキャラを解放できるようになったようですね。", elem_id=f"tts-input")
+                            with gr.Accordion(label="Phoneme Input", open=False):
+                                temp_text_var = gr.Variable()
+                                symbol_input = gr.Checkbox(value=False, label="Symbol input")
+                                symbol_list = gr.Dataset(label="Symbol list", components=[textbox],
+                                                         samples=[[x] for x in symbols],
+                                                         elem_id=f"symbol-list")
+                                symbol_list_json = gr.Json(value=symbols, visible=False)
+                                symbol_input.change(to_symbol_fn,
+                                                    [symbol_input, textbox, temp_text_var],
+                                                    [textbox, temp_text_var])
+                                symbol_list.click(None, [symbol_list, symbol_list_json], textbox,
+                                                  _js=f"""
+                                (i, symbols, text) => {{
+                                    let root = document.querySelector("body > gradio-app");
+                                    if (root.shadowRoot != null)
+                                        root = root.shadowRoot;
+                                    let text_input = root.querySelector("#tts-input").querySelector("textarea");
+                                    let startPos = text_input.selectionStart;
+                                    let endPos = text_input.selectionEnd;
+                                    let oldTxt = text_input.value;
+                                    let result = oldTxt.substring(0, startPos) + symbols[i] + oldTxt.substring(endPos);
+                                    text_input.value = result;
+                                    let x = window.scrollX, y = window.scrollY;
+                                    text_input.focus();
+                                    text_input.selectionStart = startPos + symbols[i].length;
+                                    text_input.selectionEnd = startPos + symbols[i].length;
+                                    text_input.blur();
+                                    window.scrollTo(x, y);
+                                    text = text_input.value;
+                                    return text;
+                                }}""")
+                            # select character
+                            char_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label='character')
+                            language_dropdown = gr.Dropdown(choices=lang, value=lang[0], label='language')
+                            ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
+                            nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
+                            duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1,
+                                                        label='速度 Speed')
+                        with gr.Column():
+                            text_output = gr.Textbox(label="Message")
+                            audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
+                            btn = gr.Button("Generate!")
+                            btn.click(tts_fn,
+                                      inputs=[textbox, char_dropdown, language_dropdown, ns, nsw, duration_slider,
+                                              symbol_input],
+                                      outputs=[text_output, audio_output])
+                    gr.Examples(
+                        examples=example,
+                        inputs=[textbox, char_dropdown, language_dropdown,
+                                duration_slider, symbol_input],
+                        outputs=[text_output, audio_output],
+                        fn=tts_fn
+                    )
+            with gr.Tab("Voice Conversion"):
+                for i, (description, speakers, vc_fn) in enumerate(
+                        models_vc):
+                    gr.Markdown("""
+                    录制或上传声音,并选择要转换的音色。
+                    """)
+                    with gr.Column():
+                        record_audio = gr.Audio(label="record your voice", source="microphone")
+                        upload_audio = gr.Audio(label="or upload audio here", source="upload")
+                        source_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="source speaker")
+                        target_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="target speaker")
+                    with gr.Column():
+                        message_box = gr.Textbox(label="Message")
+                        converted_audio = gr.Audio(label='converted audio')
+                    btn = gr.Button("Convert!")
+                    btn.click(vc_fn, inputs=[source_speaker, target_speaker, record_audio, upload_audio],
+                              outputs=[message_box, converted_audio])
+
+    app.queue(concurrency_count=3).launch(show_api=False, share=args.share)
+
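The rebuilt TTS tab wires the new "Symbol input" checkbox to to_symbol_fn from the previous hunk: ticking it swaps the textbox content for the cleaned phoneme string and stashes the raw text in a hidden gr.Variable; unticking restores the stash. (The Voice Conversion tab's Chinese prompt translates to "Record or upload audio, and choose the voice to convert to.") A standalone sketch of that round trip — the clean parameter stands in for the app's _clean_text(text, cleaners), which this sketch does not import:

    def to_symbol_fn(is_symbol_input, input_text, temp_text, clean=str.upper):
        # clean is a stand-in for phonemization in the real app
        return (clean(input_text), input_text) if is_symbol_input else (temp_text, temp_text)

    shown, stash = to_symbol_fn(True, "konnichiwa", "")
    assert stash == "konnichiwa"   # raw text stashed on tick
    shown, stash = to_symbol_fn(False, shown, stash)
    assert shown == "konnichiwa"   # raw text restored on untick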
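With webbrowser.open("http://127.0.0.1:7899") and the model/config flags gone, the Space boots straight from the paths baked into models_info (assuming ./OUTPUT_MODEL/G_9700.pth and ./configs/modified_finetune_speaker.json are present); locally the only remaining flag is --share:

    python app.py            # local Gradio server
    python app.py --share    # adds a public share link

The queue now allows 3 concurrent jobs instead of 1, and the HTTP API is hidden via show_api=False on launch() rather than api_open=False on queue().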