Nông Văn Thắng committed
Commit b520cd4 · 1 Parent(s): 13b062e
Files changed (1)
  1. app.py +408 -4
app.py CHANGED
@@ -1,7 +1,411 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
 
+ # This demo is adapted from https://github.com/coqui-ai/TTS/blob/dev/TTS/demos/xtts_ft_demo/xtts_demo.py
+ # With some modifications to fit the viXTTS model
+ import argparse
+ import hashlib
+ import logging
+ import os
+ import string
+ import subprocess
+ import sys
+ import tempfile
+ from datetime import datetime
+ from pprint import pprint
+
  import gradio as gr
+ import soundfile as sf
+ import torch
+ import torchaudio
+ from huggingface_hub import hf_hub_download, snapshot_download
+ from underthesea import sent_tokenize
+ from unidecode import unidecode
+ from vinorm import TTSnorm
+
+ from TTS.tts.configs.xtts_config import XttsConfig
+ from TTS.tts.models.xtts import Xtts
+
+ XTTS_MODEL = None
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+ MODEL_DIR = os.path.join(SCRIPT_DIR, "model")
+ OUTPUT_DIR = os.path.join(SCRIPT_DIR, "output")
+ FILTER_SUFFIX = "_DeepFilterNet3.wav"
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+
+ def clear_gpu_cache():
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+
+ def load_model(checkpoint_dir="model/", repo_id="capleaf/viXTTS", use_deepspeed=False):
+     global XTTS_MODEL
+     clear_gpu_cache()
+     os.makedirs(checkpoint_dir, exist_ok=True)
+
+     required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
+     files_in_dir = os.listdir(checkpoint_dir)
+     if not all(file in files_in_dir for file in required_files):
+         yield f"Missing model files! Downloading from {repo_id}..."
+         snapshot_download(
+             repo_id=repo_id,
+             repo_type="model",
+             local_dir=checkpoint_dir,
+         )
+         hf_hub_download(
+             repo_id="coqui/XTTS-v2",
+             filename="speakers_xtts.pth",
+             local_dir=checkpoint_dir,
+         )
+         yield "Model download finished..."
+
+     xtts_config = os.path.join(checkpoint_dir, "config.json")
+     config = XttsConfig()
+     config.load_json(xtts_config)
+     XTTS_MODEL = Xtts.init_from_config(config)
+     yield "Loading model..."
+     XTTS_MODEL.load_checkpoint(
+         config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
+     )
+     if torch.cuda.is_available():
+         XTTS_MODEL.cuda()
+
+     print("Model Loaded!")
+     yield "Model Loaded!"
+
+
+ # Define dictionaries to store cached results
+ cache_queue = []
+ speaker_audio_cache = {}
+ filter_cache = {}
+ conditioning_latents_cache = {}
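+ # How the caches above are used by run_tts below:
+ # - cache_queue: FIFO of reference-audio paths; invalidate_cache() evicts the oldest
+ # - filter_cache: reference path -> denoised (DeepFilterNet) copy of that file
+ # - conditioning_latents_cache: (path, conditioning settings) -> (gpt_cond_latent, speaker_embedding)
+ # - speaker_audio_cache: declared but unused in this version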
+
+
+ def invalidate_cache(cache_limit=50):
+     """Invalidate the cache for the oldest key"""
+     if len(cache_queue) > cache_limit:
+         key_to_remove = cache_queue.pop(0)
+         print("Invalidating cache", key_to_remove)
+         if os.path.exists(key_to_remove):
+             os.remove(key_to_remove)
+         if os.path.exists(key_to_remove.replace(".wav", FILTER_SUFFIX)):
+             os.remove(key_to_remove.replace(".wav", FILTER_SUFFIX))
+         if key_to_remove in filter_cache:
+             del filter_cache[key_to_remove]
+         if key_to_remove in conditioning_latents_cache:
+             del conditioning_latents_cache[key_to_remove]
+
+
+ def generate_hash(data):
+     hash_object = hashlib.md5()
+     hash_object.update(data)
+     return hash_object.hexdigest()
+
+
+ def get_file_name(text, max_char=50):
+     filename = text[:max_char]
+     filename = filename.lower()
+     filename = filename.replace(" ", "_")
+     filename = filename.translate(
+         str.maketrans("", "", string.punctuation.replace("_", ""))
+     )
+     filename = unidecode(filename)
+     current_datetime = datetime.now().strftime("%m%d%H%M%S")
+     filename = f"{current_datetime}_{filename}"
+     return filename
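+ # Example: get_file_name("Xin chào bạn!") lowercases, underscores spaces,
+ # strips punctuation, and transliterates diacritics, yielding something like
+ # "0412103045_xin_chao_ban" (the prefix is the current MMDDHHMMSS timestamp).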
+
+
+ def normalize_vietnamese_text(text):
+     text = (
+         TTSnorm(text, unknown=False, lower=False, rule=True)
+         .replace("..", ".")
+         .replace("!.", "!")
+         .replace("?.", "?")
+         .replace(" .", ".")
+         .replace(" ,", ",")
+         .replace('"', "")
+         .replace("'", "")
+         .replace("AI", "Ây Ai")
+         .replace("A.I", "Ây Ai")
+     )
+     return text
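+ # Example (assuming vinorm's TTSnorm expands numerals as usual):
+ # normalize_vietnamese_text("AI đạt 95%") -> "Ây Ai đạt chín mươi lăm phần trăm"
+ # The chained .replace() calls tidy TTSnorm's punctuation artifacts and respell
+ # "AI" phonetically so the model reads it aloud correctly.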
+
+
+ def calculate_keep_len(text, lang):
+     """Simple hack for short sentences"""
+     if lang in ["ja", "zh-cn"]:
+         return -1
+
+     word_count = len(text.split())
+     num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",")
+
+     if word_count < 5:
+         return 15000 * word_count + 2000 * num_punct
+     elif word_count < 10:
+         return 13000 * word_count + 2000 * num_punct
+     return -1
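+ # Worked example: "Xin chào!" has 2 words and 1 punctuation mark, so the cap is
+ # 15000 * 2 + 2000 * 1 = 32000 samples, i.e. ~1.3 s at the 24 kHz output rate;
+ # -1 means no truncation. This trims trailing audio that the model can
+ # hallucinate on very short inputs.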
+
+
+ def run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text):
+     global filter_cache, conditioning_latents_cache, cache_queue
+
+     if XTTS_MODEL is None:
+         return "You need to run the previous step to load the model!!", None
+
+     if not speaker_audio_file:
+         return "You need to provide reference audio!!!", None
+
+     # Use the file name as the key, since it's supposed to be unique 💀
+     speaker_audio_key = speaker_audio_file
+     if speaker_audio_key not in cache_queue:
+         cache_queue.append(speaker_audio_key)
+         invalidate_cache()
+
+     # Check if filtered reference is cached
+     if use_deepfilter and speaker_audio_key in filter_cache:
+         print("Using filter cache...")
+         speaker_audio_file = filter_cache[speaker_audio_key]
+     elif use_deepfilter:
+         print("Running filter...")
+         subprocess.run(
+             [
+                 "deepFilter",
+                 speaker_audio_file,
+                 "-o",
+                 os.path.dirname(speaker_audio_file),
+             ]
+         )
+         filter_cache[speaker_audio_key] = speaker_audio_file.replace(
+             ".wav", FILTER_SUFFIX
+         )
+         speaker_audio_file = filter_cache[speaker_audio_key]
+
+     # Check if conditioning latents are cached
+     cache_key = (
+         speaker_audio_key,
+         XTTS_MODEL.config.gpt_cond_len,
+         XTTS_MODEL.config.max_ref_len,
+         XTTS_MODEL.config.sound_norm_refs,
+     )
+     if cache_key in conditioning_latents_cache:
+         print("Using conditioning latents cache...")
+         gpt_cond_latent, speaker_embedding = conditioning_latents_cache[cache_key]
+     else:
+         print("Computing conditioning latents...")
+         gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
+             audio_path=speaker_audio_file,
+             gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+             max_ref_length=XTTS_MODEL.config.max_ref_len,
+             sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+         )
+         conditioning_latents_cache[cache_key] = (gpt_cond_latent, speaker_embedding)
+
+     if normalize_text and lang == "vi":
+         tts_text = normalize_vietnamese_text(tts_text)
+
+     # Split text by sentence
+     if lang in ["ja", "zh-cn"]:
+         sentences = tts_text.split("。")
+     else:
+         sentences = sent_tokenize(tts_text)
+     pprint(sentences)
+
+     wav_chunks = []
+     for sentence in sentences:
+         if sentence.strip() == "":
+             continue
+         wav_chunk = XTTS_MODEL.inference(
+             text=sentence,
+             language=lang,
+             gpt_cond_latent=gpt_cond_latent,
+             speaker_embedding=speaker_embedding,
+             # The following values are carefully chosen for viXTTS
+             temperature=0.3,
+             length_penalty=1.0,
+             repetition_penalty=10.0,
+             top_k=30,
+             top_p=0.85,
+             enable_text_splitting=True,
+         )
+
+         # calculate_keep_len returns -1 to keep the whole chunk
+         keep_len = calculate_keep_len(sentence, lang)
+         if keep_len > 0:
+             wav_chunk["wav"] = wav_chunk["wav"][:keep_len]
+
+         wav_chunks.append(torch.tensor(wav_chunk["wav"]))
+
+     out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0)
+     gr_audio_id = os.path.basename(os.path.dirname(speaker_audio_file))
+     out_path = os.path.join(OUTPUT_DIR, f"{get_file_name(tts_text)}_{gr_audio_id}.wav")
+     print("Saving output to ", out_path)
+     torchaudio.save(out_path, out_wav, 24000)
+
+     return "Speech generated!", out_path
+
+
+ # Define a logger that mirrors stdout/stderr into a log file
+ class Logger:
+     def __init__(self, filename="log.out"):
+         self.log_file = filename
+         self.terminal = sys.stdout
+         self.log = open(self.log_file, "w")
+
+     def write(self, message):
+         self.terminal.write(message)
+         self.log.write(message)
+
+     def flush(self):
+         self.terminal.flush()
+         self.log.flush()
+
+     def isatty(self):
+         return False
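+ # isatty() -> False makes libraries that probe stdout for a terminal (e.g.
+ # progress bars) fall back to plain, log-friendly output.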
+
+
+ # Redirect stdout and stderr to a file
+ sys.stdout = Logger()
+ sys.stderr = sys.stdout
+
+
+ logging.basicConfig(
+     level=logging.ERROR,
+     format="%(asctime)s [%(levelname)s] %(message)s",
+     handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+
+ def read_logs():
+     sys.stdout.flush()
+     with open(sys.stdout.log_file, "r") as f:
+         return f.read()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="""viXTTS inference demo\n\n""",
+         formatter_class=argparse.RawTextHelpFormatter,
+     )
+     parser.add_argument(
+         "--port",
+         type=int,
+         help="Port to run the gradio demo. Default: 5003",
+         default=5003,
+     )
+
+     parser.add_argument(
+         "--model_dir",
+         type=str,
+         help="Path to the checkpoint directory. This directory must contain 4 files: model.pth, config.json, vocab.json and speakers_xtts.pth",
+         default=None,
+     )
+
+     parser.add_argument(
+         "--reference_audio",
+         type=str,
+         help="Path to the reference audio file.",
+         default=None,
+     )
+
+     args = parser.parse_args()
+     if args.model_dir:
+         MODEL_DIR = os.path.abspath(args.model_dir)
+
+     REFERENCE_AUDIO = os.path.join(SCRIPT_DIR, "assets", "vixtts_sample_female.wav")
+     if args.reference_audio:
+         REFERENCE_AUDIO = os.path.abspath(args.reference_audio)
+
+     with gr.Blocks() as demo:
+         intro = """
+         # viXTTS Inference Demo
+         Visit viXTTS on HuggingFace: [viXTTS](https://huggingface.co/capleaf/viXTTS)
+         """
+         gr.Markdown(intro)
+         with gr.Row():
+             with gr.Column() as col1:
+                 repo_id = gr.Textbox(
+                     label="HuggingFace Repo ID",
+                     value="capleaf/viXTTS",
+                 )
+                 checkpoint_dir = gr.Textbox(
+                     label="viXTTS model directory",
+                     value=MODEL_DIR,
+                 )
+
+                 use_deepspeed = gr.Checkbox(
+                     value=True, label="Use DeepSpeed for faster inference"
+                 )
+
+                 progress_load = gr.Label(label="Progress:")
+                 load_btn = gr.Button(
+                     value="Step 1 - Load viXTTS model", variant="primary"
+                 )
+
+             with gr.Column() as col2:
+                 speaker_reference_audio = gr.Audio(
+                     label="Speaker reference audio:",
+                     value=REFERENCE_AUDIO,
+                     type="filepath",
+                 )
+
+                 tts_language = gr.Dropdown(
+                     label="Language",
+                     value="vi",
+                     choices=[
+                         "vi",
+                         "en",
+                         "es",
+                         "fr",
+                         "de",
+                         "it",
+                         "pt",
+                         "pl",
+                         "tr",
+                         "ru",
+                         "nl",
+                         "cs",
+                         "ar",
+                         "zh",
+                         "hu",
+                         "ko",
+                         "ja",
+                     ],
+                 )
+
+                 use_filter = gr.Checkbox(
+                     label="Denoise Reference Audio",
+                     value=True,
+                 )
+
+                 normalize_text = gr.Checkbox(
+                     label="Normalize Input Text",
+                     value=True,
+                 )
+
+                 tts_text = gr.Textbox(
+                     label="Input Text.",
+                     # Default text: "Hello, I am a Vietnamese text-to-speech tool developed by the Nón lá team."
+                     value="Xin chào, tôi là một công cụ chuyển đổi văn bản thành giọng nói tiếng Việt được phát triển bởi nhóm Nón lá.",
+                 )
+                 tts_btn = gr.Button(value="Step 2 - Inference", variant="primary")
+
+             with gr.Column() as col3:
+                 progress_gen = gr.Label(label="Progress:")
+                 tts_output_audio = gr.Audio(label="Generated Audio.")
+
+         load_btn.click(
+             fn=load_model,
+             inputs=[checkpoint_dir, repo_id, use_deepspeed],
+             outputs=[progress_load],
+         )
+
+         tts_btn.click(
+             fn=run_tts,
+             inputs=[
+                 tts_language,
+                 tts_text,
+                 speaker_reference_audio,
+                 use_filter,
+                 normalize_text,
+             ],
+             outputs=[progress_gen, tts_output_audio],
+         )
+
+     demo.launch(server_port=args.port)
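+ # Usage sketch (paths are illustrative):
+ #     python app.py --port 5003 --model_dir model/ --reference_audio my_voice.wav
+ # Step 1 downloads capleaf/viXTTS into the model directory on first load;
+ # Step 2 synthesizes the input text in the reference speaker's voice.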