wang0507 committed on
Commit
72298b2
1 Parent(s): ab56e96

Update app.py

Files changed (1)
  1. app.py +568 -137
app.py CHANGED
@@ -1,151 +1,582 @@
- from pathlib import Path
- from typing import List, Dict, Tuple
- import matplotlib.colors as mpl_colors
-
- import pandas as pd
- import seaborn as sns
- import shinyswatch
-
- from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
-
- sns.set_theme()
-
- www_dir = Path(__file__).parent.resolve() / "www"
-
- df = pd.read_csv(Path(__file__).parent / "penguins.csv", na_values="NA")
- numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
- species: List[str] = df["Species"].unique().tolist()
- species.sort()
-
- app_ui = ui.page_fillable(
-     shinyswatch.theme.minty(),
-     ui.layout_sidebar(
-         ui.sidebar(
-             # Artwork by @allison_horst
-             ui.input_selectize(
-                 "xvar",
-                 "X variable",
-                 numeric_cols,
-                 selected="Bill Length (mm)",
-             ),
-             ui.input_selectize(
-                 "yvar",
-                 "Y variable",
-                 numeric_cols,
-                 selected="Bill Depth (mm)",
-             ),
-             ui.input_checkbox_group(
-                 "species", "Filter by species", species, selected=species
-             ),
-             ui.hr(),
-             ui.input_switch("by_species", "Show species", value=True),
-             ui.input_switch("show_margins", "Show marginal plots", value=True),
-         ),
-         ui.output_ui("value_boxes"),
-         ui.output_plot("scatter", fill=True),
-         ui.help_text(
-             "Artwork by ",
-             ui.a("@allison_horst", href="https://twitter.com/allison_horst"),
-             class_="text-end",
-         ),
-     ),
- )
-
-
- def server(input: Inputs, output: Outputs, session: Session):
-     @reactive.Calc
-     def filtered_df() -> pd.DataFrame:
-         """Returns a Pandas data frame that includes only the desired rows"""
-
-         # This calculation "req"uires that at least one species is selected
-         req(len(input.species()) > 0)
-
-         # Filter the rows so we only include the desired species
-         return df[df["Species"].isin(input.species())]
-
-     @output
-     @render.plot
-     def scatter():
-         """Generates a plot for Shiny to display to the user"""
-
-         # The plotting function to use depends on whether margins are desired
-         plotfunc = sns.jointplot if input.show_margins() else sns.scatterplot
-
-         plotfunc(
-             data=filtered_df(),
-             x=input.xvar(),
-             y=input.yvar(),
-             palette=palette,
-             hue="Species" if input.by_species() else None,
-             hue_order=species,
-             legend=False,
-         )
-
-     @output
-     @render.ui
-     def value_boxes():
-         df = filtered_df()
-
-         def penguin_value_box(title: str, count: int, bgcol: str, showcase_img: str):
-             return ui.value_box(
-                 title,
-                 count,
-                 {"class_": "pt-1 pb-0"},
-                 showcase=ui.fill.as_fill_item(
-                     ui.tags.img(
-                         {"style": "object-fit:contain;"},
-                         src=showcase_img,
-                     )
-                 ),
-                 theme_color=None,
-                 style=f"background-color: {bgcol};",
-             )
-
-         if not input.by_species():
-             return penguin_value_box(
-                 "Penguins",
-                 len(df.index),
-                 bg_palette["default"],
-                 # Artwork by @allison_horst
-                 showcase_img="penguins.png",
-             )
-
-         value_boxes = [
-             penguin_value_box(
-                 name,
-                 len(df[df["Species"] == name]),
-                 bg_palette[name],
-                 # Artwork by @allison_horst
-                 showcase_img=f"{name}.png",
-             )
-             for name in species
-             # Only include boxes for _selected_ species
-             if name in input.species()
-         ]
-
-         return ui.layout_column_wrap(*value_boxes, width=1 / len(value_boxes))
-
-
- # "darkorange", "purple", "cyan4"
- colors = [[255, 140, 0], [160, 32, 240], [0, 139, 139]]
- colors = [(r / 255.0, g / 255.0, b / 255.0) for r, g, b in colors]
-
- palette: Dict[str, Tuple[float, float, float]] = {
-     "Adelie": colors[0],
-     "Chinstrap": colors[1],
-     "Gentoo": colors[2],
-     "default": sns.color_palette()[0],  # type: ignore
- }
-
- bg_palette = {}
- # Use `sns.set_style("whitegrid")` to help find approx alpha value
- for name, col in palette.items():
-     # Adjusted n_colors until `axe` accessibility did not complain about color contrast
-     bg_palette[name] = mpl_colors.to_hex(sns.light_palette(col, n_colors=7)[1])  # type: ignore
-
-
- app = App(
-     app_ui,
-     server,
-     static_assets=str(www_dir),
- )
+ import logging
+ import os
+ import pathlib
+ import time
+ import tempfile
+ import platform
+ import gc
+ if platform.system().lower() == 'windows':
+     temp = pathlib.PosixPath
+     pathlib.PosixPath = pathlib.WindowsPath
+ elif platform.system().lower() == 'linux':
+     temp = pathlib.WindowsPath
+     pathlib.WindowsPath = pathlib.PosixPath
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+ import langid
+ langid.set_languages(['en', 'zh', 'ja'])
+
+ import torch
+ import torchaudio
+
+ import numpy as np
+
+ from data.tokenizer import (
+     AudioTokenizer,
+     tokenize_audio,
+ )
+ from data.collation import get_text_token_collater
+ from models.vallex import VALLE
+ from utils.g2p import PhonemeBpeTokenizer
+ from descriptions import *
+ from macros import *
+ from examples import *
+
+ import gradio as gr
+ from vocos import Vocos
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+ torch._C._jit_set_profiling_executor(False)
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._set_graph_executor_optimize(False)
+
+ text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
+ text_collater = get_text_token_collater()
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+     device = torch.device("cuda", 0)
+
+ # VALL-E-X model
+ model = VALLE(
+     N_DIM,
+     NUM_HEAD,
+     NUM_LAYERS,
+     norm_first=True,
+     add_prenet=False,
+     prefix_mode=PREFIX_MODE,
+     share_embedding=True,
+     nar_scale_factor=1.0,
+     prepend_bos=True,
+     num_quantizers=NUM_QUANTIZERS,
+ ).to(device)
+ checkpoint = torch.load("./epoch-10.pt", map_location='cpu')
+ missing_keys, unexpected_keys = model.load_state_dict(
+     checkpoint["model"], strict=True
+ )
+ del checkpoint
+ assert not missing_keys
+ model.eval()
+
+ # Encodec model
+ audio_tokenizer = AudioTokenizer(device)
+
+ # Vocos decoder
+ vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
+
+ # ASR
+ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
+ whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to(device)
+ whisper.config.forced_decoder_ids = None
+
+ # Voice Presets
+ preset_list = os.walk("./presets/").__next__()[2]
+ preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]
+
+ def clear_prompts():
+     try:
+         path = tempfile.gettempdir()
+         for eachfile in os.listdir(path):
+             filename = os.path.join(path, eachfile)
+             if os.path.isfile(filename) and filename.endswith(".npz"):
+                 lastmodifytime = os.stat(filename).st_mtime
+                 endfiletime = time.time() - 60
+                 if endfiletime > lastmodifytime:
+                     os.remove(filename)
+         del path, filename, lastmodifytime, endfiletime
+         gc.collect()
+     except:
+         return
+
+ def transcribe_one(wav, sr):
+     if sr != 16000:
+         wav4trans = torchaudio.transforms.Resample(sr, 16000)(wav)
+     else:
+         wav4trans = wav
+
+     input_features = whisper_processor(wav4trans.squeeze(0), sampling_rate=16000, return_tensors="pt").input_features
+
+     # generate token ids
+     predicted_ids = whisper.generate(input_features.to(device))
+     lang = whisper_processor.batch_decode(predicted_ids[:, 1])[0].strip("<|>")
+     # decode token ids to text
+     text_pr = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+     # print the recognized text
+     print(text_pr)
+
+     if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
+         text_pr += "."
+
+     # delete all variables
+     del wav4trans, input_features, predicted_ids
+     gc.collect()
+     return lang, text_pr
+
+ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
+     clear_prompts()
+     audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
+     sr, wav_pr = audio_prompt
+     if len(wav_pr) / sr > 15:
+         return "Rejected, Audio too long (should be less than 15 seconds)", None
+     if not isinstance(wav_pr, torch.FloatTensor):
+         wav_pr = torch.FloatTensor(wav_pr)
+     if wav_pr.abs().max() > 1:
+         wav_pr /= wav_pr.abs().max()
+     if wav_pr.size(-1) == 2:
+         wav_pr = wav_pr[:, 0]
+     if wav_pr.ndim == 1:
+         wav_pr = wav_pr.unsqueeze(0)
+     assert wav_pr.ndim and wav_pr.size(0) == 1
+
+     if transcript_content == "":
+         lang_pr, text_pr = transcribe_one(wav_pr, sr)
+         lang_token = lang2token[lang_pr]
+         text_pr = lang_token + text_pr + lang_token
+     else:
+         lang_pr = langid.classify(str(transcript_content))[0]
+         lang_token = lang2token[lang_pr]
+         transcript_content = transcript_content.replace("\n", "")
+         text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
+     # tokenize audio
+     encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
+     audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
+
+     # tokenize text
+     phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
+     text_tokens, enroll_x_lens = text_collater(
+         [
+             phonemes
+         ]
+     )
+
+     message = f"Detected language: {lang_pr}\n Detected text: {text_pr}\n"
+     if lang_pr not in ['ja', 'zh', 'en']:
+         return f"Prompt can only be made with one of the model-supported languages, got {lang_pr} instead", None
+
+     # save as npz file
+     np.savez(os.path.join(tempfile.gettempdir(), f"{name}.npz"),
+              audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
+
+     # delete all variables
+     del audio_tokens, text_tokens, phonemes, lang_pr, text_pr, wav_pr, sr, uploaded_audio, recorded_audio
+     gc.collect()
+     return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
+
+
+ @torch.no_grad()
+ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
+     if len(text) > 150:
+         return "Rejected, Text too long (should be less than 150 characters)", None
+     if audio_prompt is None and record_audio_prompt is None:
+         audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
+         text_prompts = torch.zeros([1, 0]).type(torch.int32)
+         lang_pr = 'en'
+         text_pr = ""
+         enroll_x_lens = 0
+         wav_pr, sr = None, None
+     else:
+         audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
+         sr, wav_pr = audio_prompt
+         if len(wav_pr) / sr > 15:
+             return "Rejected, Audio too long (should be less than 15 seconds)", None
+         if not isinstance(wav_pr, torch.FloatTensor):
+             wav_pr = torch.FloatTensor(wav_pr)
+         if wav_pr.abs().max() > 1:
+             wav_pr /= wav_pr.abs().max()
+         if wav_pr.size(-1) == 2:
+             wav_pr = wav_pr[:, 0]
+         if wav_pr.ndim == 1:
+             wav_pr = wav_pr.unsqueeze(0)
+         assert wav_pr.ndim and wav_pr.size(0) == 1
+
+         if transcript_content == "":
+             lang_pr, text_pr = transcribe_one(wav_pr, sr)
+             lang_token = lang2token[lang_pr]
+             text_pr = lang_token + text_pr + lang_token
+         else:
+             lang_pr = langid.classify(str(transcript_content))[0]
+             text_pr = transcript_content.replace("\n", "")
+             if lang_pr not in ['ja', 'zh', 'en']:
+                 return f"Reference audio must be a speech of one of model-supported languages, got {lang_pr} instead", None
+             lang_token = lang2token[lang_pr]
+             text_pr = lang_token + text_pr + lang_token
+
+         # tokenize audio
+         encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
+         audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
+
+         enroll_x_lens = None
+         if text_pr:
+             text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
+             text_prompts, enroll_x_lens = text_collater(
+                 [
+                     text_prompts
+                 ]
+             )
+
+     if language == 'auto-detect':
+         lang_token = lang2token[langid.classify(text)[0]]
+     else:
+         lang_token = langdropdown2token[language]
+     lang = token2lang[lang_token]
+     text = text.replace("\n", "")
+     text = lang_token + text + lang_token
+
+     # tokenize text
+     logging.info(f"synthesize text: {text}")
+     phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+     text_tokens, text_tokens_lens = text_collater(
+         [
+             phone_tokens
+         ]
+     )
+
+     text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+     text_tokens_lens += enroll_x_lens
+     lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+     encoded_frames = model.inference(
+         text_tokens.to(device),
+         text_tokens_lens.to(device),
+         audio_prompts,
+         enroll_x_lens=enroll_x_lens,
+         top_k=-100,
+         temperature=1,
+         prompt_language=lang_pr,
+         text_language=langs if accent == "no-accent" else lang,
+     )
+     # Decode with Vocos
+     frames = encoded_frames.permute(2, 0, 1)
+     features = vocos.codes_to_features(frames)
+     samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
+     message = f"text prompt: {text_pr}\nsynthesized text: {text}"
+     # delete all variables
+     del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, wav_pr, sr, audio_prompt, record_audio_prompt, transcript_content
+     gc.collect()
+     return message, (24000, samples.squeeze(0).cpu().numpy())
+
+ @torch.no_grad()
+ def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
+     if len(text) > 150:
+         return "Rejected, Text too long (should be less than 150 characters)", None
+     clear_prompts()
+     # text to synthesize
+     if language == 'auto-detect':
+         lang_token = lang2token[langid.classify(text)[0]]
+     else:
+         lang_token = langdropdown2token[language]
+     lang = token2lang[lang_token]
+     text = text.replace("\n", "")
+     text = lang_token + text + lang_token
+
+     # load prompt
+     if prompt_file is not None:
+         prompt_data = np.load(prompt_file.name)
+     else:
+         prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz"))
+     audio_prompts = prompt_data['audio_tokens']
+     text_prompts = prompt_data['text_tokens']
+     lang_pr = prompt_data['lang_code']
+     lang_pr = code2lang[int(lang_pr)]
+
+     # numpy to tensor
+     audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
+     text_prompts = torch.tensor(text_prompts).type(torch.int32)
+
+     enroll_x_lens = text_prompts.shape[-1]
+     logging.info(f"synthesize text: {text}")
+     phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+     text_tokens, text_tokens_lens = text_collater(
+         [
+             phone_tokens
+         ]
+     )
+     text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+     text_tokens_lens += enroll_x_lens
+     # accent control
+     lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+     encoded_frames = model.inference(
+         text_tokens.to(device),
+         text_tokens_lens.to(device),
+         audio_prompts,
+         enroll_x_lens=enroll_x_lens,
+         top_k=-100,
+         temperature=1,
+         prompt_language=lang_pr,
+         text_language=langs if accent == "no-accent" else lang,
+     )
+     # Decode with Vocos
+     frames = encoded_frames.permute(2, 0, 1)
+     features = vocos.codes_to_features(frames)
+     samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
+     message = f"synthesized text: {text}"
+
+     # delete all variables
+     del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, prompt_file, preset_prompt
+     gc.collect()
+     return message, (24000, samples.squeeze(0).cpu().numpy())
+
+
+ from utils.sentence_cutter import split_text_into_sentences
+ @torch.no_grad()
+ def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='no-accent'):
+     """
+     For long audio generation, two modes are available.
+     fixed-prompt: This mode keeps using the same prompt the user has provided, and generates audio sentence by sentence.
+     sliding-window: This mode uses the last sentence as the prompt for the next sentence, but may have trouble maintaining speaker consistency.
+     """
+     if len(text) > 1000:
+         return "Rejected, Text too long (should be less than 1000 characters)", None
+     mode = 'fixed-prompt'
+     if (prompt is None or prompt == "") and preset_prompt == "":
+         mode = 'sliding-window'  # If no prompt is given, use sliding-window mode
+     sentences = split_text_into_sentences(text)
+     # detect language
+     if language == "auto-detect":
+         language = langid.classify(text)[0]
+     else:
+         language = token2lang[langdropdown2token[language]]
+
+     # if an initial prompt is given, encode it
+     if prompt is not None and prompt != "":
+         # load prompt
+         prompt_data = np.load(prompt.name)
+         audio_prompts = prompt_data['audio_tokens']
+         text_prompts = prompt_data['text_tokens']
+         lang_pr = prompt_data['lang_code']
+         lang_pr = code2lang[int(lang_pr)]
+
+         # numpy to tensor
+         audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
+         text_prompts = torch.tensor(text_prompts).type(torch.int32)
+     elif preset_prompt is not None and preset_prompt != "":
+         prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz"))
+         audio_prompts = prompt_data['audio_tokens']
+         text_prompts = prompt_data['text_tokens']
+         lang_pr = prompt_data['lang_code']
+         lang_pr = code2lang[int(lang_pr)]
+
+         # numpy to tensor
+         audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
+         text_prompts = torch.tensor(text_prompts).type(torch.int32)
+     else:
+         audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
+         text_prompts = torch.zeros([1, 0]).type(torch.int32)
+         lang_pr = language if language != 'mix' else 'en'
+     if mode == 'fixed-prompt':
+         complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
+         for text in sentences:
+             text = text.replace("\n", "").strip(" ")
+             if text == "":
+                 continue
+             lang_token = lang2token[language]
+             lang = token2lang[lang_token]
+             text = lang_token + text + lang_token
+
+             enroll_x_lens = text_prompts.shape[-1]
+             logging.info(f"synthesize text: {text}")
+             phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+             text_tokens, text_tokens_lens = text_collater(
+                 [
+                     phone_tokens
+                 ]
+             )
+             text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+             text_tokens_lens += enroll_x_lens
+             # accent control
+             lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+             encoded_frames = model.inference(
+                 text_tokens.to(device),
+                 text_tokens_lens.to(device),
+                 audio_prompts,
+                 enroll_x_lens=enroll_x_lens,
+                 top_k=-100,
+                 temperature=1,
+                 prompt_language=lang_pr,
+                 text_language=langs if accent == "no-accent" else lang,
+             )
+             complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
+         # Decode with Vocos
+         frames = complete_tokens.permute(1, 0, 2)
+         features = vocos.codes_to_features(frames)
+         samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
+         message = f"Cut into {len(sentences)} sentences"
+         return message, (24000, samples.squeeze(0).cpu().numpy())
+     elif mode == "sliding-window":
+         complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
+         original_audio_prompts = audio_prompts
+         original_text_prompts = text_prompts
+         for text in sentences:
+             text = text.replace("\n", "").strip(" ")
+             if text == "":
+                 continue
+             lang_token = lang2token[language]
+             lang = token2lang[lang_token]
+             text = lang_token + text + lang_token
+
+             enroll_x_lens = text_prompts.shape[-1]
+             logging.info(f"synthesize text: {text}")
+             phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+             text_tokens, text_tokens_lens = text_collater(
+                 [
+                     phone_tokens
+                 ]
+             )
+             text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+             text_tokens_lens += enroll_x_lens
+             # accent control
+             lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+             encoded_frames = model.inference(
+                 text_tokens.to(device),
+                 text_tokens_lens.to(device),
+                 audio_prompts,
+                 enroll_x_lens=enroll_x_lens,
+                 top_k=-100,
+                 temperature=1,
+                 prompt_language=lang_pr,
+                 text_language=langs if accent == "no-accent" else lang,
+             )
+             complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
+             if torch.rand(1) < 1.0:
+                 audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
+                 text_prompts = text_tokens[:, enroll_x_lens:]
+             else:
+                 audio_prompts = original_audio_prompts
+                 text_prompts = original_text_prompts
+         # Decode with Vocos
+         frames = complete_tokens.permute(1, 0, 2)
+         features = vocos.codes_to_features(frames)
+         samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
+         message = f"Cut into {len(sentences)} sentences"
+
+         return message, (24000, samples.squeeze(0).cpu().numpy())
+     else:
+         raise ValueError(f"No such mode {mode}")
+
+ app = gr.Blocks()
+ with app:
+     gr.Markdown(top_md)
+     with gr.Tab("Infer from audio"):
+         gr.Markdown(infer_from_audio_md)
+         with gr.Row():
+             with gr.Column():
+                 textbox = gr.TextArea(label="Text",
+                     placeholder="Type your sentence here",
+                     value="Welcome back, Master. What can I do for you today?", elem_id=f"tts-input")
+                 language_dropdown = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語'], value='auto-detect', label='language')
+                 accent_dropdown = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent', label='accent')
+                 textbox_transcript = gr.TextArea(label="Transcript",
+                     placeholder="Write transcript here. (leave empty to use whisper)",
+                     value="", elem_id=f"prompt-name")
+                 upload_audio_prompt = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
+                 record_audio_prompt = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
+             with gr.Column():
+                 text_output = gr.Textbox(label="Message")
+                 audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
+                 btn = gr.Button("Generate!")
+                 btn.click(infer_from_audio,
+                     inputs=[textbox, language_dropdown, accent_dropdown, upload_audio_prompt, record_audio_prompt, textbox_transcript],
+                     outputs=[text_output, audio_output])
+                 textbox_mp = gr.TextArea(label="Prompt name",
+                     placeholder="Name your prompt here",
+                     value="prompt_1", elem_id=f"prompt-name")
+                 btn_mp = gr.Button("Make prompt!")
+                 prompt_output = gr.File(interactive=False)
+                 btn_mp.click(make_npz_prompt,
+                     inputs=[textbox_mp, upload_audio_prompt, record_audio_prompt, textbox_transcript],
+                     outputs=[text_output, prompt_output])
+         gr.Examples(examples=infer_from_audio_examples,
+             inputs=[textbox, language_dropdown, accent_dropdown, upload_audio_prompt, record_audio_prompt, textbox_transcript],
+             outputs=[text_output, audio_output],
+             fn=infer_from_audio,
+             cache_examples=False,)
+     with gr.Tab("Make prompt"):
+         gr.Markdown(make_prompt_md)
+         with gr.Row():
+             with gr.Column():
+                 textbox2 = gr.TextArea(label="Prompt name",
+                     placeholder="Name your prompt here",
+                     value="prompt_1", elem_id=f"prompt-name")
+                 # Area for choosing the language and entering the transcript
+                 textbox_transcript2 = gr.TextArea(label="Transcript",
+                     placeholder="Write transcript here. (leave empty to use whisper)",
+                     value="", elem_id=f"prompt-name")
+                 upload_audio_prompt_2 = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
+                 record_audio_prompt_2 = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
+             with gr.Column():
+                 text_output_2 = gr.Textbox(label="Message")
+                 prompt_output_2 = gr.File(interactive=False)
+                 btn_2 = gr.Button("Make!")
+                 btn_2.click(make_npz_prompt,
+                     inputs=[textbox2, upload_audio_prompt_2, record_audio_prompt_2, textbox_transcript2],
+                     outputs=[text_output_2, prompt_output_2])
+         gr.Examples(examples=make_npz_prompt_examples,
+             inputs=[textbox2, upload_audio_prompt_2, record_audio_prompt_2, textbox_transcript2],
+             outputs=[text_output_2, prompt_output_2],
+             fn=make_npz_prompt,
+             cache_examples=False,)
+     with gr.Tab("Infer from prompt"):
+         gr.Markdown(infer_from_prompt_md)
+         with gr.Row():
+             with gr.Column():
+                 textbox_3 = gr.TextArea(label="Text",
+                     placeholder="Type your sentence here",
+                     value="Welcome back, Master. What can I do for you today?", elem_id=f"tts-input")
+                 language_dropdown_3 = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語', 'Mix'], value='auto-detect',
+                     label='language')
+                 accent_dropdown_3 = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent',
+                     label='accent')
+                 preset_dropdown_3 = gr.Dropdown(choices=preset_list, value=None, label='Voice preset')
+                 prompt_file = gr.File(file_count='single', file_types=['.npz'], interactive=True)
+             with gr.Column():
+                 text_output_3 = gr.Textbox(label="Message")
+                 audio_output_3 = gr.Audio(label="Output Audio", elem_id="tts-audio")
+                 btn_3 = gr.Button("Generate!")
+                 btn_3.click(infer_from_prompt,
+                     inputs=[textbox_3, language_dropdown_3, accent_dropdown_3, preset_dropdown_3, prompt_file],
+                     outputs=[text_output_3, audio_output_3])
+         gr.Examples(examples=infer_from_prompt_examples,
+             inputs=[textbox_3, language_dropdown_3, accent_dropdown_3, preset_dropdown_3, prompt_file],
+             outputs=[text_output_3, audio_output_3],
+             fn=infer_from_prompt,
+             cache_examples=False,)
+     with gr.Tab("Infer long text"):
+         gr.Markdown(long_text_md)
+         with gr.Row():
+             with gr.Column():
+                 textbox_4 = gr.TextArea(label="Text",
+                     placeholder="Type your sentence here",
+                     value=long_text_example, elem_id=f"tts-input")
+                 language_dropdown_4 = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語'], value='auto-detect',
+                     label='language')
+                 accent_dropdown_4 = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent',
+                     label='accent')
+                 preset_dropdown_4 = gr.Dropdown(choices=preset_list, value=None, label='Voice preset')
+                 prompt_file_4 = gr.File(file_count='single', file_types=['.npz'], interactive=True)
+             with gr.Column():
+                 text_output_4 = gr.TextArea(label="Message")
+                 audio_output_4 = gr.Audio(label="Output Audio", elem_id="tts-audio")
+                 btn_4 = gr.Button("Generate!")
+                 btn_4.click(infer_long_text,
+                     inputs=[textbox_4, preset_dropdown_4, prompt_file_4, language_dropdown_4, accent_dropdown_4],
+                     outputs=[text_output_4, audio_output_4])
+
+ app.launch()