jason-on-salt-a40 commited on
Commit
d63a00c
·
1 Parent(s): ca27bc7

whisperx, more models, better instructions

Browse files
Files changed (2) hide show
  1. app.py +342 -251
  2. app_old.py +528 -0
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
- # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3
- # os.environ["CUDA_VISIBLE_DEVICES"] = "1" # these are only used if developping locally
4
  import gradio as gr
5
  import torch
6
  import torchaudio
@@ -12,10 +10,19 @@ from models import voicecraft
12
  import io
13
  import numpy as np
14
  import random
 
15
  import spaces
16
 
17
 
18
- whisper_model, voicecraft_model = None, None
 
 
 
 
 
 
 
 
19
 
20
  @spaces.GPU(duration=30)
21
  def seed_everything(seed):
@@ -29,29 +36,71 @@ def seed_everything(seed):
29
  torch.backends.cudnn.deterministic = True
30
 
31
  @spaces.GPU(duration=120)
32
- def load_models(whisper_model_choice, voicecraft_model_choice):
33
- global whisper_model, voicecraft_model
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- if whisper_model_choice is not None:
36
- import whisper
37
  from whisper.tokenizer import get_tokenizer
38
- whisper_model = {
39
- "model": whisper.load_model(whisper_model_choice),
40
- "tokenizer": get_tokenizer(multilingual=False)
41
- }
 
 
42
 
 
 
43
 
44
- device = "cuda" if torch.cuda.is_available() else "cpu"
45
-
46
- voicecraft_name = f"{voicecraft_model_choice}.pth"
47
- ckpt_fn = f"./pretrained_models/{voicecraft_name}"
48
- encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  if not os.path.exists(ckpt_fn):
50
  os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
51
- os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
52
  if not os.path.exists(encodec_fn):
53
  os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
54
- os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
55
 
56
  ckpt = torch.load(ckpt_fn, map_location="cpu")
57
  model = voicecraft.VoiceCraft(ckpt["config"])
@@ -67,32 +116,78 @@ def load_models(whisper_model_choice, voicecraft_model_choice):
67
 
68
  return gr.Accordion()
69
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  @spaces.GPU(duration=60)
71
  def transcribe(seed, audio_path):
72
- if whisper_model is None:
73
- raise gr.Error("Whisper model not loaded")
74
  seed_everything(seed)
75
-
76
- number_tokens = [
77
- i
78
- for i in range(whisper_model["tokenizer"].eot)
79
- if all(c in "0123456789" for c in whisper_model["tokenizer"].decode([i]).removeprefix(" "))
 
 
 
 
 
80
  ]
81
- result = whisper_model["model"].transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True)
82
- words = [word_info for segment in result["segments"] for word_info in segment["words"]]
83
-
84
- transcript = result["text"]
85
- transcript_with_start_time = " ".join([f"{word['start']} {word['word']}" for word in words])
86
- transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words])
87
 
88
- choices = [f"{word['start']} {word['word']} {word['end']}" for word in words]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  return [
91
- transcript, transcript_with_start_time, transcript_with_end_time,
92
- gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # prompt_to_word
93
- gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word
94
- gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word
95
- words
96
  ]
97
 
98
 
@@ -106,12 +201,12 @@ def get_output_audio(audio_tensors, codec_audio_sr):
106
  @spaces.GPU(duration=90)
107
  def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
108
  stop_repetition, sample_batch_size, kvcache, silence_tokens,
109
- audio_path, word_info, transcript, smart_transcript,
110
  mode, prompt_end_time, edit_start_time, edit_end_time,
111
  split_text, selected_sentence, previous_audio_tensors):
112
  if voicecraft_model is None:
113
  raise gr.Error("VoiceCraft model not loaded")
114
- if smart_transcript and (word_info is None):
115
  raise gr.Error("Can't use smart transcript: whisper transcript not found")
116
 
117
  seed_everything(seed)
@@ -128,7 +223,6 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
128
  else:
129
  sentences = [transcript.replace("\n", " ")]
130
 
131
- device = "cuda" if torch.cuda.is_available() else "cpu"
132
  info = torchaudio.info(audio_path)
133
  audio_dur = info.num_frames / info.sample_rate
134
 
@@ -141,14 +235,14 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
141
  if mode != "Edit":
142
  from inference_tts_scale import inference_one_sample
143
 
144
- if smart_transcript:
145
  target_transcript = ""
146
- for word in word_info:
147
  if word["end"] < prompt_end_time:
148
- target_transcript += word["word"]
149
  elif (word["start"] + word["end"]) / 2 < prompt_end_time:
150
  # include part of the word it it's big, but adjust prompt_end_time
151
- target_transcript += word["word"]
152
  prompt_end_time = word["end"]
153
  break
154
  else:
@@ -171,15 +265,15 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
171
 
172
  if smart_transcript:
173
  target_transcript = ""
174
- for word in word_info:
175
  if word["start"] < edit_start_time:
176
- target_transcript += word["word"]
177
  else:
178
  break
179
  target_transcript += f" {sentence}"
180
- for word in word_info:
181
  if word["end"] > edit_end_time:
182
- target_transcript += word["word"]
183
  else:
184
  target_transcript = sentence
185
 
@@ -188,7 +282,7 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
188
  morphed_span = (max(edit_start_time - left_margin, 1 / codec_sr), min(edit_end_time + right_margin, audio_dur))
189
  mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]]
190
  mask_interval = torch.LongTensor(mask_interval)
191
-
192
  _, gen_audio = inference_one_sample(voicecraft_model["model"],
193
  voicecraft_model["ckpt"]["config"],
194
  voicecraft_model["ckpt"]["phn2num"],
@@ -207,12 +301,12 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
207
  output_audio = get_output_audio(previous_audio_tensors, codec_audio_sr)
208
  sentence_audio = get_output_audio(audio_tensors, codec_audio_sr)
209
  return output_audio, inference_transcript, sentence_audio, previous_audio_tensors
210
-
211
-
212
  def update_input_audio(audio_path):
213
  if audio_path is None:
214
  return 0, 0, 0
215
-
216
  info = torchaudio.info(audio_path)
217
  max_time = round(info.num_frames / info.sample_rate, 2)
218
  return [
@@ -221,7 +315,7 @@ def update_input_audio(audio_path):
221
  gr.Slider(maximum=max_time, value=max_time),
222
  ]
223
 
224
-
225
  def change_mode(mode):
226
  tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor
227
  return [
@@ -278,84 +372,52 @@ demo_original_transcript = " But when I had approached so near to them, the comm
278
 
279
  demo_text = {
280
  "TTS": {
281
- "smart": "I cannot believe that the same model can also do text to speech synthesis as well!",
282
- "regular": "But when I had approached so near to them, the common I cannot believe that the same model can also do text to speech synthesis as well!"
283
  },
284
  "Edit": {
285
  "smart": "saw the mirage of the lake in the distance,",
286
  "regular": "But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,"
287
  },
288
  "Long TTS": {
289
- "smart": "You can run generation on a big text!\n"
290
  "Just write it line-by-line. Or sentence-by-sentence.\n"
291
- "If some sentences sound odd, just rerun generation on them, no need to generate the whole text again!",
292
- "regular": "But when I had approached so near to them, the common You can run generation on a big text!\n"
293
  "But when I had approached so near to them, the common Just write it line-by-line. Or sentence-by-sentence.\n"
294
- "But when I had approached so near to them, the common If some sentences sound odd, just rerun generation on them, no need to generate the whole text again!"
295
  }
296
  }
297
 
298
  all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()}
299
 
300
  demo_words = [
301
- "0.03 but 0.18",
302
- "0.18 when 0.32",
303
- "0.32 i 0.48",
304
- "0.48 had 0.64",
305
- "0.64 approached 1.19",
306
- "1.22 so 1.58",
307
- "1.58 near 1.91",
308
- "1.91 to 2.07",
309
- "2.07 them 2.42",
310
- "2.53 the 2.61",
311
- "2.61 common 3.01",
312
- "3.05 object 3.62",
313
- "3.68 which 3.93",
314
- "3.93 the 4.02",
315
- "4.02 sense 4.34",
316
- "4.34 deceives 4.97",
317
- "5.04 lost 5.54",
318
- "5.54 not 6.00",
319
- "6.00 by 6.14",
320
- "6.14 distance 6.67",
321
- "6.79 any 7.05",
322
- "7.05 of 7.18",
323
- "7.18 its 7.34",
324
- "7.34 marks 7.87"
325
  ]
326
 
327
- demo_word_info = [
328
- {"word": "but", "start": 0.03, "end": 0.18},
329
- {"word": "when", "start": 0.18, "end": 0.32},
330
- {"word": "i", "start": 0.32, "end": 0.48},
331
- {"word": "had", "start": 0.48, "end": 0.64},
332
- {"word": "approached", "start": 0.64, "end": 1.19},
333
- {"word": "so", "start": 1.22, "end": 1.58},
334
- {"word": "near", "start": 1.58, "end": 1.91},
335
- {"word": "to", "start": 1.91, "end": 2.07},
336
- {"word": "them", "start": 2.07, "end": 2.42},
337
- {"word": "the", "start": 2.53, "end": 2.61},
338
- {"word": "common", "start": 2.61, "end": 3.01},
339
- {"word": "object", "start": 3.05, "end": 3.62},
340
- {"word": "which", "start": 3.68, "end": 3.93},
341
- {"word": "the", "start": 3.93, "end": 4.02},
342
- {"word": "sense", "start": 4.02, "end": 4.34},
343
- {"word": "deceives", "start": 4.34, "end": 4.97},
344
- {"word": "lost", "start": 5.04, "end": 5.54},
345
- {"word": "not", "start": 5.54, "end": 6.0},
346
- {"word": "by", "start": 6.0, "end": 6.14},
347
- {"word": "distance", "start": 6.14, "end": 6.67},
348
- {"word": "any", "start": 6.79, "end": 7.05},
349
- {"word": "of", "start": 7.05, "end": 7.18},
350
- {"word": "its", "start": 7.18, "end": 7.34},
351
- {"word": "marks", "start": 7.34, "end": 7.87}
352
  ]
353
 
354
 
355
  def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word):
356
  if transcript not in all_demo_texts:
357
  return transcript, edit_from_word, edit_to_word
358
-
359
  replace_half = edit_word_mode == "Replace half"
360
  change_edit_from_word = edit_from_word == demo_words[2] or edit_from_word == demo_words[3]
361
  change_edit_to_word = edit_to_word == demo_words[11] or edit_to_word == demo_words[12]
@@ -368,161 +430,190 @@ def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_wo
368
  ]
369
 
370
 
371
- with gr.Blocks() as app:
372
- with gr.Row():
373
- with gr.Column(scale=2):
374
- load_models_btn = gr.Button(value="Load models")
375
- with gr.Column(scale=5):
376
- with gr.Accordion("Select models", open=False) as models_selector:
377
- with gr.Row():
378
- voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"])
379
- whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
380
- choices=[None, "tiny.en", "base.en", "small.en", "medium.en", "large"])
381
-
382
- with gr.Row():
383
- with gr.Column(scale=2):
384
- input_audio = gr.Audio(sources=["upload", "microphone"], value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
385
- with gr.Group():
386
- original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, interactive=False,
387
- info="Use whisper model to get the transcript. Fix it if necessary.")
388
- with gr.Accordion("Word start time", open=False):
389
- transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
390
- with gr.Accordion("Word end time", open=False):
391
- transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
392
-
393
- transcribe_btn = gr.Button(value="Transcribe")
394
-
395
- with gr.Column(scale=3):
396
- with gr.Group():
397
- transcript = gr.Textbox(label="Text", lines=7, value=demo_text["TTS"]["smart"])
398
- with gr.Row():
399
- smart_transcript = gr.Checkbox(label="Smart transcript", value=True)
400
- with gr.Accordion(label="?", open=False):
401
- info = gr.Markdown(value=smart_transcript_info)
402
-
403
- with gr.Row():
404
- mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS")
405
- split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline",
406
- info="Split text into parts and run TTS for each part.", visible=False)
407
- edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half",
408
- info="What to do with first and last word", visible=False)
409
-
410
- with gr.Group() as tts_mode_controls:
411
- prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True)
412
- prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01)
413
-
414
- with gr.Group(visible=False) as edit_mode_controls:
415
  with gr.Row():
416
- edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[2], interactive=True)
417
- edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[12], interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  with gr.Row():
419
- edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.01, value=0.35)
420
- edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.01, value=3.75)
421
-
422
- run_btn = gr.Button(value="Run")
423
-
424
- with gr.Column(scale=2):
425
- output_audio = gr.Audio(label="Output Audio")
426
- with gr.Accordion("Inference transcript", open=False):
427
- inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False,
428
- info="Inference was performed on this transcript.")
429
- with gr.Group(visible=False) as long_tts_sentence_editor:
430
- sentence_selector = gr.Dropdown(label="Sentence", value=None,
431
- info="Select sentence you want to regenerate")
432
- sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
433
- rerun_btn = gr.Button(value="Rerun")
434
-
435
- with gr.Row():
436
- with gr.Accordion("VoiceCraft config", open=False):
437
- seed = gr.Number(label="seed", value=-1, precision=0)
438
- left_margin = gr.Number(label="left_margin", value=0.08)
439
- right_margin = gr.Number(label="right_margin", value=0.08)
440
- codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000)
441
- codec_sr = gr.Number(label="codec_sr", value=50)
442
- top_k = gr.Number(label="top_k", value=0)
443
- top_p = gr.Number(label="top_p", value=0.8)
444
- temperature = gr.Number(label="temperature", value=1)
445
- stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3], value=3,
446
- info="if there are long silence in the generated audio, reduce the stop_repetition to 3, 2 or even 1, -1 = disabled")
447
- sample_batch_size = gr.Number(label="sample_batch_size", value=4, precision=0,
448
- info="generate this many samples and choose the shortest one")
449
- kvcache = gr.Radio(label="kvcache", choices=[0, 1], value=1,
450
- info="set to 0 to use less VRAM, but with slower inference")
451
- silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]")
452
 
453
-
454
- audio_tensors = gr.State()
455
- word_info = gr.State(value=demo_word_info)
456
-
457
-
458
- mode.change(fn=update_demo,
459
- inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
460
- outputs=[transcript, edit_from_word, edit_to_word])
461
- edit_word_mode.change(fn=update_demo,
462
- inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
463
- outputs=[transcript, edit_from_word, edit_to_word])
464
- smart_transcript.change(fn=update_demo,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
466
  outputs=[transcript, edit_from_word, edit_to_word])
467
-
468
- load_models_btn.click(fn=load_models,
469
- inputs=[whisper_model_choice, voicecraft_model_choice],
470
- outputs=[models_selector])
471
-
472
- input_audio.change(fn=update_input_audio,
473
- inputs=[input_audio],
474
- outputs=[prompt_end_time, edit_start_time, edit_end_time])
475
- transcribe_btn.click(fn=transcribe,
476
- inputs=[seed, input_audio],
477
- outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
478
- prompt_to_word, edit_from_word, edit_to_word, word_info])
479
-
480
- mode.change(fn=change_mode,
481
- inputs=[mode],
482
- outputs=[tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor])
483
-
484
- run_btn.click(fn=run,
485
- inputs=[
486
- seed, left_margin, right_margin,
487
- codec_audio_sr, codec_sr,
488
- top_k, top_p, temperature,
489
- stop_repetition, sample_batch_size,
490
- kvcache, silence_tokens,
491
- input_audio, word_info, transcript, smart_transcript,
492
- mode, prompt_end_time, edit_start_time, edit_end_time,
493
- split_text, sentence_selector, audio_tensors
494
- ],
495
- outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors])
496
-
497
- sentence_selector.change(fn=load_sentence,
498
- inputs=[sentence_selector, codec_audio_sr, audio_tensors],
499
- outputs=[sentence_audio])
500
- rerun_btn.click(fn=run,
501
  inputs=[
502
  seed, left_margin, right_margin,
503
  codec_audio_sr, codec_sr,
504
  top_k, top_p, temperature,
505
  stop_repetition, sample_batch_size,
506
  kvcache, silence_tokens,
507
- input_audio, word_info, transcript, smart_transcript,
508
- gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time,
509
  split_text, sentence_selector, audio_tensors
510
  ],
511
- outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
512
-
513
- prompt_to_word.change(fn=update_bound_word,
514
- inputs=[gr.State(False), prompt_to_word, gr.State("Replace all")],
515
- outputs=[prompt_end_time])
516
- edit_from_word.change(fn=update_bound_word,
517
- inputs=[gr.State(True), edit_from_word, edit_word_mode],
518
- outputs=[edit_start_time])
519
- edit_to_word.change(fn=update_bound_word,
520
- inputs=[gr.State(False), edit_to_word, edit_word_mode],
521
- outputs=[edit_end_time])
522
- edit_word_mode.change(fn=update_bound_words,
523
- inputs=[edit_from_word, edit_to_word, edit_word_mode],
524
- outputs=[edit_start_time, edit_end_time])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
 
527
  if __name__ == "__main__":
528
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
2
  import gradio as gr
3
  import torch
4
  import torchaudio
 
10
  import io
11
  import numpy as np
12
  import random
13
+ import uuid
14
  import spaces
15
 
16
 
17
+ DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
18
+ TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
19
+ MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ whisper_model, align_model, voicecraft_model = None, None, None
22
+
23
+
24
+ def get_random_string():
25
+ return "".join(str(uuid.uuid4()).split("-"))
26
 
27
  @spaces.GPU(duration=30)
28
  def seed_everything(seed):
 
36
  torch.backends.cudnn.deterministic = True
37
 
38
  @spaces.GPU(duration=120)
39
+ class WhisperxAlignModel:
40
+ def __init__(self):
41
+ from whisperx import load_align_model
42
+ self.model, self.metadata = load_align_model(language_code="en", device=device)
43
+
44
+ def align(self, segments, audio_path):
45
+ from whisperx import align, load_audio
46
+ audio = load_audio(audio_path)
47
+ return align(segments, self.model, self.metadata, audio, device, return_char_alignments=False)["segments"]
48
+
49
+ @spaces.GPU(duration=120)
50
+ class WhisperModel:
51
+ def __init__(self, model_name):
52
+ from whisper import load_model
53
+ self.model = load_model(model_name, device)
54
 
 
 
55
  from whisper.tokenizer import get_tokenizer
56
+ tokenizer = get_tokenizer(multilingual=False)
57
+ self.supress_tokens = [-1] + [
58
+ i
59
+ for i in range(tokenizer.eot)
60
+ if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
61
+ ]
62
 
63
+ def transcribe(self, audio_path):
64
+ return self.model.transcribe(audio_path, suppress_tokens=self.supress_tokens, word_timestamps=True)["segments"]
65
 
66
+ @spaces.GPU(duration=120)
67
+ class WhisperxModel:
68
+ def __init__(self, model_name, align_model: WhisperxAlignModel):
69
+ from whisperx import load_model
70
+ self.model = load_model(model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None})
71
+ self.align_model = align_model
72
+
73
+ def transcribe(self, audio_path):
74
+ segments = self.model.transcribe(audio_path, batch_size=8)["segments"]
75
+ return self.align_model.align(segments, audio_path)
76
+
77
+ @spaces.GPU(duration=120)
78
+ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, voicecraft_model_name):
79
+ global transcribe_model, align_model, voicecraft_model
80
+
81
+ if voicecraft_model_name == "giga330M_TTSEnhanced":
82
+ voicecraft_model_name = "gigaHalfLibri330M_TTSEnhanced_max16s"
83
+
84
+ if alignment_model_name is not None:
85
+ align_model = WhisperxAlignModel()
86
+
87
+ if whisper_model_name is not None:
88
+ if whisper_backend_name == "whisper":
89
+ transcribe_model = WhisperModel(whisper_model_name)
90
+ else:
91
+ if align_model is None:
92
+ raise gr.Error("Align model required for whisperx backend")
93
+ transcribe_model = WhisperxModel(whisper_model_name, align_model)
94
+
95
+ voicecraft_name = f"{voicecraft_model_name}.pth"
96
+ ckpt_fn = f"{MODELS_PATH}/{voicecraft_name}"
97
+ encodec_fn = f"{MODELS_PATH}/encodec_4cb2048_giga.th"
98
  if not os.path.exists(ckpt_fn):
99
  os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
100
+ os.system(f"mv {voicecraft_name}\?download\=true {MODELS_PATH}/{voicecraft_name}")
101
  if not os.path.exists(encodec_fn):
102
  os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
103
+ os.system(f"mv encodec_4cb2048_giga.th {MODELS_PATH}/encodec_4cb2048_giga.th")
104
 
105
  ckpt = torch.load(ckpt_fn, map_location="cpu")
106
  model = voicecraft.VoiceCraft(ckpt["config"])
 
116
 
117
  return gr.Accordion()
118
 
119
+
120
+ def get_transcribe_state(segments):
121
+ words_info = [word_info for segment in segments for word_info in segment["words"]]
122
+ return {
123
+ "segments": segments,
124
+ "transcript": " ".join([segment["text"] for segment in segments]),
125
+ "words_info": words_info,
126
+ "transcript_with_start_time": " ".join([f"{word['start']} {word['word']}" for word in words_info]),
127
+ "transcript_with_end_time": " ".join([f"{word['word']} {word['end']}" for word in words_info]),
128
+ "word_bounds": [f"{word['start']} {word['word']} {word['end']}" for word in words_info]
129
+ }
130
+
131
  @spaces.GPU(duration=60)
132
  def transcribe(seed, audio_path):
133
+ if transcribe_model is None:
134
+ raise gr.Error("Transcription model not loaded")
135
  seed_everything(seed)
136
+
137
+ segments = transcribe_model.transcribe(audio_path)
138
+ state = get_transcribe_state(segments)
139
+
140
+ return [
141
+ state["transcript"], state["transcript_with_start_time"], state["transcript_with_end_time"],
142
+ gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # prompt_to_word
143
+ gr.Dropdown(value=state["word_bounds"][0], choices=state["word_bounds"], interactive=True), # edit_from_word
144
+ gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # edit_to_word
145
+ state
146
  ]
 
 
 
 
 
 
147
 
148
+
149
+ def align_segments(transcript, audio_path):
150
+ from aeneas.executetask import ExecuteTask
151
+ from aeneas.task import Task
152
+ import json
153
+ config_string = 'task_language=eng|os_task_file_format=json|is_text_type=plain'
154
+
155
+ tmp_transcript_path = os.path.join(TMP_PATH, f"{get_random_string()}.txt")
156
+ tmp_sync_map_path = os.path.join(TMP_PATH, f"{get_random_string()}.json")
157
+ with open(tmp_transcript_path, "w") as f:
158
+ f.write(transcript)
159
+
160
+ task = Task(config_string=config_string)
161
+ task.audio_file_path_absolute = os.path.abspath(audio_path)
162
+ task.text_file_path_absolute = os.path.abspath(tmp_transcript_path)
163
+ task.sync_map_file_path_absolute = os.path.abspath(tmp_sync_map_path)
164
+ ExecuteTask(task).execute()
165
+ task.output_sync_map_file()
166
+
167
+ with open(tmp_sync_map_path, "r") as f:
168
+ return json.load(f)
169
+
170
+ @spaces.GPU(duration=90)
171
+ def align(seed, transcript, audio_path):
172
+ if align_model is None:
173
+ raise gr.Error("Align model not loaded")
174
+ seed_everything(seed)
175
+
176
+ fragments = align_segments(transcript, audio_path)
177
+ segments = [{
178
+ "start": float(fragment["begin"]),
179
+ "end": float(fragment["end"]),
180
+ "text": " ".join(fragment["lines"])
181
+ } for fragment in fragments["fragments"]]
182
+ segments = align_model.align(segments, audio_path)
183
+ state = get_transcribe_state(segments)
184
 
185
  return [
186
+ state["transcript_with_start_time"], state["transcript_with_end_time"],
187
+ gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # prompt_to_word
188
+ gr.Dropdown(value=state["word_bounds"][0], choices=state["word_bounds"], interactive=True), # edit_from_word
189
+ gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # edit_to_word
190
+ state
191
  ]
192
 
193
 
 
201
  @spaces.GPU(duration=90)
202
  def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
203
  stop_repetition, sample_batch_size, kvcache, silence_tokens,
204
+ audio_path, transcribe_state, transcript, smart_transcript,
205
  mode, prompt_end_time, edit_start_time, edit_end_time,
206
  split_text, selected_sentence, previous_audio_tensors):
207
  if voicecraft_model is None:
208
  raise gr.Error("VoiceCraft model not loaded")
209
+ if smart_transcript and (transcribe_state is None):
210
  raise gr.Error("Can't use smart transcript: whisper transcript not found")
211
 
212
  seed_everything(seed)
 
223
  else:
224
  sentences = [transcript.replace("\n", " ")]
225
 
 
226
  info = torchaudio.info(audio_path)
227
  audio_dur = info.num_frames / info.sample_rate
228
 
 
235
  if mode != "Edit":
236
  from inference_tts_scale import inference_one_sample
237
 
238
+ if smart_transcript:
239
  target_transcript = ""
240
+ for word in transcribe_state["words_info"]:
241
  if word["end"] < prompt_end_time:
242
+ target_transcript += word["word"] + (" " if word["word"][-1] != " " else "")
243
  elif (word["start"] + word["end"]) / 2 < prompt_end_time:
244
  # include part of the word it it's big, but adjust prompt_end_time
245
+ target_transcript += word["word"] + (" " if word["word"][-1] != " " else "")
246
  prompt_end_time = word["end"]
247
  break
248
  else:
 
265
 
266
  if smart_transcript:
267
  target_transcript = ""
268
+ for word in transcribe_state["words_info"]:
269
  if word["start"] < edit_start_time:
270
+ target_transcript += word["word"] + (" " if word["word"][-1] != " " else "")
271
  else:
272
  break
273
  target_transcript += f" {sentence}"
274
+ for word in transcribe_state["words_info"]:
275
  if word["end"] > edit_end_time:
276
+ target_transcript += word["word"] + (" " if word["word"][-1] != " " else "")
277
  else:
278
  target_transcript = sentence
279
 
 
282
  morphed_span = (max(edit_start_time - left_margin, 1 / codec_sr), min(edit_end_time + right_margin, audio_dur))
283
  mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]]
284
  mask_interval = torch.LongTensor(mask_interval)
285
+
286
  _, gen_audio = inference_one_sample(voicecraft_model["model"],
287
  voicecraft_model["ckpt"]["config"],
288
  voicecraft_model["ckpt"]["phn2num"],
 
301
  output_audio = get_output_audio(previous_audio_tensors, codec_audio_sr)
302
  sentence_audio = get_output_audio(audio_tensors, codec_audio_sr)
303
  return output_audio, inference_transcript, sentence_audio, previous_audio_tensors
304
+
305
+
306
  def update_input_audio(audio_path):
307
  if audio_path is None:
308
  return 0, 0, 0
309
+
310
  info = torchaudio.info(audio_path)
311
  max_time = round(info.num_frames / info.sample_rate, 2)
312
  return [
 
315
  gr.Slider(maximum=max_time, value=max_time),
316
  ]
317
 
318
+
319
  def change_mode(mode):
320
  tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor
321
  return [
 
372
 
373
  demo_text = {
374
  "TTS": {
375
+ "smart": "I cannot believe that the same model can also do text to speech synthesis too!",
376
+ "regular": "But when I had approached so near to them, the common I cannot believe that the same model can also do text to speech synthesis too!"
377
  },
378
  "Edit": {
379
  "smart": "saw the mirage of the lake in the distance,",
380
  "regular": "But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,"
381
  },
382
  "Long TTS": {
383
+ "smart": "You can run the model on a big text!\n"
384
  "Just write it line-by-line. Or sentence-by-sentence.\n"
385
+ "If some sentences sound odd, just rerun the model on them, no need to generate the whole text again!",
386
+ "regular": "But when I had approached so near to them, the common You can run the model on a big text!\n"
387
  "But when I had approached so near to them, the common Just write it line-by-line. Or sentence-by-sentence.\n"
388
+ "But when I had approached so near to them, the common If some sentences sound odd, just rerun the model on them, no need to generate the whole text again!"
389
  }
390
  }
391
 
392
  all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()}
393
 
394
  demo_words = [
395
+ '0.029 But 0.149', '0.189 when 0.33', '0.43 I 0.49', '0.53 had 0.65', '0.711 approached 1.152', '1.352 so 1.593',
396
+ '1.693 near 1.933', '1.994 to 2.074', '2.134 them, 2.354', '2.535 the 2.655', '2.695 common 3.016', '3.196 object, 3.577',
397
+ '3.717 which 3.898', '3.958 the 4.058', '4.098 sense 4.359', '4.419 deceives, 4.92', '5.101 lost 5.481', '5.682 not 5.963',
398
+ '6.043 by 6.183', '6.223 distance 6.644', '6.905 any 7.065', '7.125 of 7.185', '7.245 its 7.346', '7.406 marks. 7.727'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  ]
400
 
401
+ demo_words_info = [
402
+ {'word': 'But', 'start': 0.029, 'end': 0.149, 'score': 0.834}, {'word': 'when', 'start': 0.189, 'end': 0.33, 'score': 0.879},
403
+ {'word': 'I', 'start': 0.43, 'end': 0.49, 'score': 0.984}, {'word': 'had', 'start': 0.53, 'end': 0.65, 'score': 0.998},
404
+ {'word': 'approached', 'start': 0.711, 'end': 1.152, 'score': 0.822}, {'word': 'so', 'start': 1.352, 'end': 1.593, 'score': 0.822},
405
+ {'word': 'near', 'start': 1.693, 'end': 1.933, 'score': 0.752}, {'word': 'to', 'start': 1.994, 'end': 2.074, 'score': 0.924},
406
+ {'word': 'them,', 'start': 2.134, 'end': 2.354, 'score': 0.914}, {'word': 'the', 'start': 2.535, 'end': 2.655, 'score': 0.818},
407
+ {'word': 'common', 'start': 2.695, 'end': 3.016, 'score': 0.971}, {'word': 'object,', 'start': 3.196, 'end': 3.577, 'score': 0.823},
408
+ {'word': 'which', 'start': 3.717, 'end': 3.898, 'score': 0.701}, {'word': 'the', 'start': 3.958, 'end': 4.058, 'score': 0.798},
409
+ {'word': 'sense', 'start': 4.098, 'end': 4.359, 'score': 0.797}, {'word': 'deceives,', 'start': 4.419, 'end': 4.92, 'score': 0.802},
410
+ {'word': 'lost', 'start': 5.101, 'end': 5.481, 'score': 0.71}, {'word': 'not', 'start': 5.682, 'end': 5.963, 'score': 0.781},
411
+ {'word': 'by', 'start': 6.043, 'end': 6.183, 'score': 0.834}, {'word': 'distance', 'start': 6.223, 'end': 6.644, 'score': 0.899},
412
+ {'word': 'any', 'start': 6.905, 'end': 7.065, 'score': 0.893}, {'word': 'of', 'start': 7.125, 'end': 7.185, 'score': 0.772},
413
+ {'word': 'its', 'start': 7.245, 'end': 7.346, 'score': 0.778}, {'word': 'marks.', 'start': 7.406, 'end': 7.727, 'score': 0.955}
 
 
 
 
 
 
 
 
 
 
 
 
414
  ]
415
 
416
 
417
  def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word):
418
  if transcript not in all_demo_texts:
419
  return transcript, edit_from_word, edit_to_word
420
+
421
  replace_half = edit_word_mode == "Replace half"
422
  change_edit_from_word = edit_from_word == demo_words[2] or edit_from_word == demo_words[3]
423
  change_edit_to_word = edit_to_word == demo_words[11] or edit_to_word == demo_words[12]
 
430
  ]
431
 
432
 
433
+ def get_app():
434
+ with gr.Blocks() as app:
435
+ with gr.Row():
436
+ with gr.Column(scale=2):
437
+ load_models_btn = gr.Button(value="Load models")
438
+ with gr.Column(scale=5):
439
+ with gr.Accordion("Select models", open=False) as models_selector:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  with gr.Row():
441
+ voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M",
442
+ choices=["giga330M", "giga830M", "giga330M_TTSEnhanced"])
443
+ whisper_backend_choice = gr.Radio(label="Whisper backend", value="whisperX", choices=["whisper", "whisperX"])
444
+ whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
445
+ choices=[None, "base.en", "small.en", "medium.en", "large"])
446
+ align_model_choice = gr.Radio(label="Forced alignment model", value="whisperX", choices=[None, "whisperX"])
447
+
448
+ with gr.Row():
449
+ with gr.Column(scale=2):
450
+ input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
451
+ with gr.Group():
452
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript,
453
+ info="Use whisper model to get the transcript. Fix and align it if necessary.")
454
+ with gr.Accordion("Word start time", open=False):
455
+ transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
456
+ with gr.Accordion("Word end time", open=False):
457
+ transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
458
+
459
+ transcribe_btn = gr.Button(value="Transcribe")
460
+ align_btn = gr.Button(value="Align")
461
+
462
+ with gr.Column(scale=3):
463
+ with gr.Group():
464
+ transcript = gr.Textbox(label="Text", lines=7, value=demo_text["TTS"]["smart"])
465
  with gr.Row():
466
+ smart_transcript = gr.Checkbox(label="Smart transcript", value=True)
467
+ with gr.Accordion(label="?", open=False):
468
+ info = gr.Markdown(value=smart_transcript_info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
+ with gr.Row():
471
+ mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS")
472
+ split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline",
473
+ info="Split text into parts and run TTS for each part.", visible=False)
474
+ edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half",
475
+ info="What to do with first and last word", visible=False)
476
+
477
+ with gr.Group() as tts_mode_controls:
478
+ prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True)
479
+ prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.001, value=3.016)
480
+
481
+ with gr.Group(visible=False) as edit_mode_controls:
482
+ with gr.Row():
483
+ edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[2], interactive=True)
484
+ edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[12], interactive=True)
485
+ with gr.Row():
486
+ edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.001, value=0.46)
487
+ edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.001, value=3.808)
488
+
489
+ run_btn = gr.Button(value="Run")
490
+
491
+ with gr.Column(scale=2):
492
+ output_audio = gr.Audio(label="Output Audio")
493
+ with gr.Accordion("Inference transcript", open=False):
494
+ inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False,
495
+ info="Inference was performed on this transcript.")
496
+ with gr.Group(visible=False) as long_tts_sentence_editor:
497
+ sentence_selector = gr.Dropdown(label="Sentence", value=None,
498
+ info="Select sentence you want to regenerate")
499
+ sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
500
+ rerun_btn = gr.Button(value="Rerun")
501
+
502
+ with gr.Row():
503
+ with gr.Accordion("Generation Parameters - change these if you are unhappy with the generation", open=False):
504
+ stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3, 4], value=3,
505
+ info="if there are long silence in the generated audio, reduce the stop_repetition to 2 or 1. -1 = disabled")
506
+ sample_batch_size = gr.Number(label="speech rate", value=4, precision=0,
507
+ info="The higher the number, the faster the output will be. "
508
+ "Under the hood, the model will generate this many samples and choose the shortest one. "
509
+ "For giga330M_TTSEnhanced, 1 or 2 should be fine since the model is trained to do TTS.")
510
+ seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
511
+ kvcache = gr.Radio(label="kvcache", choices=[0, 1], value=1,
512
+ info="set to 0 to use less VRAM, but with slower inference")
513
+ left_margin = gr.Number(label="left_margin", value=0.08, info="margin to the left of the editing segment")
514
+ right_margin = gr.Number(label="right_margin", value=0.08, info="margin to the right of the editing segment")
515
+ top_p = gr.Number(label="top_p", value=0.9, info="0.9 is a good value, 0.8 is also good")
516
+ temperature = gr.Number(label="temperature", value=1, info="haven't try other values, do not recommend to change")
517
+ top_k = gr.Number(label="top_k", value=0, info="0 means we don't use topk sampling, because we use topp sampling")
518
+ codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000, info='encodec specific, Do not change')
519
+ codec_sr = gr.Number(label="codec_sr", value=50, info='encodec specific, Do not change')
520
+ silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]", info="encodec specific, do not change")
521
+
522
+
523
+ audio_tensors = gr.State()
524
+ transcribe_state = gr.State(value={"words_info": demo_words_info})
525
+
526
+
527
+ mode.change(fn=update_demo,
528
+ inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
529
+ outputs=[transcript, edit_from_word, edit_to_word])
530
+ edit_word_mode.change(fn=update_demo,
531
  inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
532
  outputs=[transcript, edit_from_word, edit_to_word])
533
+ smart_transcript.change(fn=update_demo,
534
+ inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
535
+ outputs=[transcript, edit_from_word, edit_to_word])
536
+
537
+ load_models_btn.click(fn=load_models,
538
+ inputs=[whisper_backend_choice, whisper_model_choice, align_model_choice, voicecraft_model_choice],
539
+ outputs=[models_selector])
540
+
541
+ input_audio.upload(fn=update_input_audio,
542
+ inputs=[input_audio],
543
+ outputs=[prompt_end_time, edit_start_time, edit_end_time])
544
+ transcribe_btn.click(fn=transcribe,
545
+ inputs=[seed, input_audio],
546
+ outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
547
+ prompt_to_word, edit_from_word, edit_to_word, transcribe_state])
548
+ align_btn.click(fn=align,
549
+ inputs=[seed, original_transcript, input_audio],
550
+ outputs=[transcript_with_start_time, transcript_with_end_time,
551
+ prompt_to_word, edit_from_word, edit_to_word, transcribe_state])
552
+
553
+ mode.change(fn=change_mode,
554
+ inputs=[mode],
555
+ outputs=[tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor])
556
+
557
+ run_btn.click(fn=run,
 
 
 
 
 
 
 
 
 
558
  inputs=[
559
  seed, left_margin, right_margin,
560
  codec_audio_sr, codec_sr,
561
  top_k, top_p, temperature,
562
  stop_repetition, sample_batch_size,
563
  kvcache, silence_tokens,
564
+ input_audio, transcribe_state, transcript, smart_transcript,
565
+ mode, prompt_end_time, edit_start_time, edit_end_time,
566
  split_text, sentence_selector, audio_tensors
567
  ],
568
+ outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors])
569
+
570
+ sentence_selector.change(fn=load_sentence,
571
+ inputs=[sentence_selector, codec_audio_sr, audio_tensors],
572
+ outputs=[sentence_audio])
573
+ rerun_btn.click(fn=run,
574
+ inputs=[
575
+ seed, left_margin, right_margin,
576
+ codec_audio_sr, codec_sr,
577
+ top_k, top_p, temperature,
578
+ stop_repetition, sample_batch_size,
579
+ kvcache, silence_tokens,
580
+ input_audio, transcribe_state, transcript, smart_transcript,
581
+ gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time,
582
+ split_text, sentence_selector, audio_tensors
583
+ ],
584
+ outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
585
+
586
+ prompt_to_word.change(fn=update_bound_word,
587
+ inputs=[gr.State(False), prompt_to_word, gr.State("Replace all")],
588
+ outputs=[prompt_end_time])
589
+ edit_from_word.change(fn=update_bound_word,
590
+ inputs=[gr.State(True), edit_from_word, edit_word_mode],
591
+ outputs=[edit_start_time])
592
+ edit_to_word.change(fn=update_bound_word,
593
+ inputs=[gr.State(False), edit_to_word, edit_word_mode],
594
+ outputs=[edit_end_time])
595
+ edit_word_mode.change(fn=update_bound_words,
596
+ inputs=[edit_from_word, edit_to_word, edit_word_mode],
597
+ outputs=[edit_start_time, edit_end_time])
598
+ return app
599
 
600
 
601
  if __name__ == "__main__":
602
+ import argparse
603
+
604
+ parser = argparse.ArgumentParser(description="VoiceCraft gradio app.")
605
+
606
+ parser.add_argument("--demo-path", default="./demo", help="Path to demo directory")
607
+ parser.add_argument("--tmp-path", default="./demo/temp", help="Path to tmp directory")
608
+ parser.add_argument("--models-path", default="./pretrained_models", help="Path to voicecraft models directory")
609
+ parser.add_argument("--port", default=7860, type=int, help="App port")
610
+ parser.add_argument("--share", action="store_true", help="Launch with public url")
611
+
612
+ os.environ["USER"] = os.getenv("USER", "user")
613
+ args = parser.parse_args()
614
+ DEMO_PATH = args.demo_path
615
+ TMP_PATH = args.tmp_path
616
+ MODELS_PATH = args.models_path
617
+
618
+ app = get_app()
619
+ app.queue().launch(share=args.share, server_port=args.port)
app_old.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "1" # these are only used if developping locally
4
+ import gradio as gr
5
+ import torch
6
+ import torchaudio
7
+ from data.tokenizer import (
8
+ AudioTokenizer,
9
+ TextTokenizer,
10
+ )
11
+ from models import voicecraft
12
+ import io
13
+ import numpy as np
14
+ import random
15
+ import spaces
16
+
17
+
18
+ whisper_model, voicecraft_model = None, None
19
+
20
+ @spaces.GPU(duration=30)
21
+ def seed_everything(seed):
22
+ if seed != -1:
23
+ os.environ['PYTHONHASHSEED'] = str(seed)
24
+ random.seed(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+ torch.backends.cudnn.benchmark = False
29
+ torch.backends.cudnn.deterministic = True
30
+
31
+ @spaces.GPU(duration=120)
32
+ def load_models(whisper_model_choice, voicecraft_model_choice):
33
+ global whisper_model, voicecraft_model
34
+
35
+ if whisper_model_choice is not None:
36
+ import whisper
37
+ from whisper.tokenizer import get_tokenizer
38
+ whisper_model = {
39
+ "model": whisper.load_model(whisper_model_choice),
40
+ "tokenizer": get_tokenizer(multilingual=False)
41
+ }
42
+
43
+
44
+ device = "cuda" if torch.cuda.is_available() else "cpu"
45
+
46
+ voicecraft_name = f"{voicecraft_model_choice}.pth"
47
+ ckpt_fn = f"./pretrained_models/{voicecraft_name}"
48
+ encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
49
+ if not os.path.exists(ckpt_fn):
50
+ os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
51
+ os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
52
+ if not os.path.exists(encodec_fn):
53
+ os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
54
+ os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
55
+
56
+ ckpt = torch.load(ckpt_fn, map_location="cpu")
57
+ model = voicecraft.VoiceCraft(ckpt["config"])
58
+ model.load_state_dict(ckpt["model"])
59
+ model.to(device)
60
+ model.eval()
61
+ voicecraft_model = {
62
+ "ckpt": ckpt,
63
+ "model": model,
64
+ "text_tokenizer": TextTokenizer(backend="espeak"),
65
+ "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
66
+ }
67
+
68
+ return gr.Accordion()
69
+
70
+ @spaces.GPU(duration=60)
71
+ def transcribe(seed, audio_path):
72
+ if whisper_model is None:
73
+ raise gr.Error("Whisper model not loaded")
74
+ seed_everything(seed)
75
+
76
+ number_tokens = [
77
+ i
78
+ for i in range(whisper_model["tokenizer"].eot)
79
+ if all(c in "0123456789" for c in whisper_model["tokenizer"].decode([i]).removeprefix(" "))
80
+ ]
81
+ result = whisper_model["model"].transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True)
82
+ words = [word_info for segment in result["segments"] for word_info in segment["words"]]
83
+
84
+ transcript = result["text"]
85
+ transcript_with_start_time = " ".join([f"{word['start']} {word['word']}" for word in words])
86
+ transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words])
87
+
88
+ choices = [f"{word['start']} {word['word']} {word['end']}" for word in words]
89
+
90
+ return [
91
+ transcript, transcript_with_start_time, transcript_with_end_time,
92
+ gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # prompt_to_word
93
+ gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word
94
+ gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word
95
+ words
96
+ ]
97
+
98
+
99
+ def get_output_audio(audio_tensors, codec_audio_sr):
100
+ result = torch.cat(audio_tensors, 1)
101
+ buffer = io.BytesIO()
102
+ torchaudio.save(buffer, result, int(codec_audio_sr), format="wav")
103
+ buffer.seek(0)
104
+ return buffer.read()
105
+
106
+ @spaces.GPU(duration=90)
107
+ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
108
+ stop_repetition, sample_batch_size, kvcache, silence_tokens,
109
+ audio_path, word_info, transcript, smart_transcript,
110
+ mode, prompt_end_time, edit_start_time, edit_end_time,
111
+ split_text, selected_sentence, previous_audio_tensors):
112
+ if voicecraft_model is None:
113
+ raise gr.Error("VoiceCraft model not loaded")
114
+ if smart_transcript and (word_info is None):
115
+ raise gr.Error("Can't use smart transcript: whisper transcript not found")
116
+
117
+ seed_everything(seed)
118
+ if mode == "Long TTS":
119
+ if split_text == "Newline":
120
+ sentences = transcript.split('\n')
121
+ else:
122
+ from nltk.tokenize import sent_tokenize
123
+ sentences = sent_tokenize(transcript.replace("\n", " "))
124
+ elif mode == "Rerun":
125
+ colon_position = selected_sentence.find(':')
126
+ selected_sentence_idx = int(selected_sentence[:colon_position])
127
+ sentences = [selected_sentence[colon_position + 1:]]
128
+ else:
129
+ sentences = [transcript.replace("\n", " ")]
130
+
131
+ device = "cuda" if torch.cuda.is_available() else "cpu"
132
+ info = torchaudio.info(audio_path)
133
+ audio_dur = info.num_frames / info.sample_rate
134
+
135
+ audio_tensors = []
136
+ inference_transcript = ""
137
+ for sentence in sentences:
138
+ decode_config = {"top_k": top_k, "top_p": top_p, "temperature": temperature, "stop_repetition": stop_repetition,
139
+ "kvcache": kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
140
+ "silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size}
141
+ if mode != "Edit":
142
+ from inference_tts_scale import inference_one_sample
143
+
144
+ if smart_transcript:
145
+ target_transcript = ""
146
+ for word in word_info:
147
+ if word["end"] < prompt_end_time:
148
+ target_transcript += word["word"]
149
+ elif (word["start"] + word["end"]) / 2 < prompt_end_time:
150
+ # include part of the word it it's big, but adjust prompt_end_time
151
+ target_transcript += word["word"]
152
+ prompt_end_time = word["end"]
153
+ break
154
+ else:
155
+ break
156
+ target_transcript += f" {sentence}"
157
+ else:
158
+ target_transcript = sentence
159
+
160
+ inference_transcript += target_transcript + "\n"
161
+
162
+ prompt_end_frame = int(min(audio_dur, prompt_end_time) * info.sample_rate)
163
+ _, gen_audio = inference_one_sample(voicecraft_model["model"],
164
+ voicecraft_model["ckpt"]["config"],
165
+ voicecraft_model["ckpt"]["phn2num"],
166
+ voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
167
+ audio_path, target_transcript, device, decode_config,
168
+ prompt_end_frame)
169
+ else:
170
+ from inference_speech_editing_scale import inference_one_sample
171
+
172
+ if smart_transcript:
173
+ target_transcript = ""
174
+ for word in word_info:
175
+ if word["start"] < edit_start_time:
176
+ target_transcript += word["word"]
177
+ else:
178
+ break
179
+ target_transcript += f" {sentence}"
180
+ for word in word_info:
181
+ if word["end"] > edit_end_time:
182
+ target_transcript += word["word"]
183
+ else:
184
+ target_transcript = sentence
185
+
186
+ inference_transcript += target_transcript + "\n"
187
+
188
+ morphed_span = (max(edit_start_time - left_margin, 1 / codec_sr), min(edit_end_time + right_margin, audio_dur))
189
+ mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]]
190
+ mask_interval = torch.LongTensor(mask_interval)
191
+
192
+ _, gen_audio = inference_one_sample(voicecraft_model["model"],
193
+ voicecraft_model["ckpt"]["config"],
194
+ voicecraft_model["ckpt"]["phn2num"],
195
+ voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
196
+ audio_path, target_transcript, mask_interval, device, decode_config)
197
+ gen_audio = gen_audio[0].cpu()
198
+ audio_tensors.append(gen_audio)
199
+
200
+ if mode != "Rerun":
201
+ output_audio = get_output_audio(audio_tensors, codec_audio_sr)
202
+ sentences = [f"{idx}: {text}" for idx, text in enumerate(sentences)]
203
+ component = gr.Dropdown(choices=sentences, value=sentences[0])
204
+ return output_audio, inference_transcript, component, audio_tensors
205
+ else:
206
+ previous_audio_tensors[selected_sentence_idx] = audio_tensors[0]
207
+ output_audio = get_output_audio(previous_audio_tensors, codec_audio_sr)
208
+ sentence_audio = get_output_audio(audio_tensors, codec_audio_sr)
209
+ return output_audio, inference_transcript, sentence_audio, previous_audio_tensors
210
+
211
+
212
+ def update_input_audio(audio_path):
213
+ if audio_path is None:
214
+ return 0, 0, 0
215
+
216
+ info = torchaudio.info(audio_path)
217
+ max_time = round(info.num_frames / info.sample_rate, 2)
218
+ return [
219
+ gr.Slider(maximum=max_time, value=max_time),
220
+ gr.Slider(maximum=max_time, value=0),
221
+ gr.Slider(maximum=max_time, value=max_time),
222
+ ]
223
+
224
+
225
+ def change_mode(mode):
226
+ tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor
227
+ return [
228
+ gr.Group(visible=mode != "Edit"),
229
+ gr.Group(visible=mode == "Edit"),
230
+ gr.Radio(visible=mode == "Edit"),
231
+ gr.Radio(visible=mode == "Long TTS"),
232
+ gr.Group(visible=mode == "Long TTS"),
233
+ ]
234
+
235
+
236
+ def load_sentence(selected_sentence, codec_audio_sr, audio_tensors):
237
+ if selected_sentence is None:
238
+ return None
239
+ colon_position = selected_sentence.find(':')
240
+ selected_sentence_idx = int(selected_sentence[:colon_position])
241
+ return get_output_audio([audio_tensors[selected_sentence_idx]], codec_audio_sr)
242
+
243
+
244
+ def update_bound_word(is_first_word, selected_word, edit_word_mode):
245
+ if selected_word is None:
246
+ return None
247
+
248
+ word_start_time = float(selected_word.split(' ')[0])
249
+ word_end_time = float(selected_word.split(' ')[-1])
250
+ if edit_word_mode == "Replace half":
251
+ bound_time = (word_start_time + word_end_time) / 2
252
+ elif is_first_word:
253
+ bound_time = word_start_time
254
+ else:
255
+ bound_time = word_end_time
256
+
257
+ return bound_time
258
+
259
+
260
+ def update_bound_words(from_selected_word, to_selected_word, edit_word_mode):
261
+ return [
262
+ update_bound_word(True, from_selected_word, edit_word_mode),
263
+ update_bound_word(False, to_selected_word, edit_word_mode),
264
+ ]
265
+
266
+
267
+ smart_transcript_info = """
268
+ If enabled, the target transcript will be constructed for you:</br>
269
+ - In TTS and Long TTS mode just write the text you want to synthesize.</br>
270
+ - In Edit mode just write the text to replace selected editing segment.</br>
271
+ If disabled, you should write the target transcript yourself:</br>
272
+ - In TTS mode write prompt transcript followed by generation transcript.</br>
273
+ - In Long TTS select split by newline (<b>SENTENCE SPLIT WON'T WORK</b>) and start each line with a prompt transcript.</br>
274
+ - In Edit mode write full prompt</br>
275
+ """
276
+
277
+ demo_original_transcript = " But when I had approached so near to them, the common object, which the sense deceives, lost not by distance any of its marks."
278
+
279
+ demo_text = {
280
+ "TTS": {
281
+ "smart": "I cannot believe that the same model can also do text to speech synthesis as well!",
282
+ "regular": "But when I had approached so near to them, the common I cannot believe that the same model can also do text to speech synthesis as well!"
283
+ },
284
+ "Edit": {
285
+ "smart": "saw the mirage of the lake in the distance,",
286
+ "regular": "But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,"
287
+ },
288
+ "Long TTS": {
289
+ "smart": "You can run generation on a big text!\n"
290
+ "Just write it line-by-line. Or sentence-by-sentence.\n"
291
+ "If some sentences sound odd, just rerun generation on them, no need to generate the whole text again!",
292
+ "regular": "But when I had approached so near to them, the common You can run generation on a big text!\n"
293
+ "But when I had approached so near to them, the common Just write it line-by-line. Or sentence-by-sentence.\n"
294
+ "But when I had approached so near to them, the common If some sentences sound odd, just rerun generation on them, no need to generate the whole text again!"
295
+ }
296
+ }
297
+
298
+ all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()}
299
+
300
+ demo_words = [
301
+ "0.03 but 0.18",
302
+ "0.18 when 0.32",
303
+ "0.32 i 0.48",
304
+ "0.48 had 0.64",
305
+ "0.64 approached 1.19",
306
+ "1.22 so 1.58",
307
+ "1.58 near 1.91",
308
+ "1.91 to 2.07",
309
+ "2.07 them 2.42",
310
+ "2.53 the 2.61",
311
+ "2.61 common 3.01",
312
+ "3.05 object 3.62",
313
+ "3.68 which 3.93",
314
+ "3.93 the 4.02",
315
+ "4.02 sense 4.34",
316
+ "4.34 deceives 4.97",
317
+ "5.04 lost 5.54",
318
+ "5.54 not 6.00",
319
+ "6.00 by 6.14",
320
+ "6.14 distance 6.67",
321
+ "6.79 any 7.05",
322
+ "7.05 of 7.18",
323
+ "7.18 its 7.34",
324
+ "7.34 marks 7.87"
325
+ ]
326
+
327
+ demo_word_info = [
328
+ {"word": "but", "start": 0.03, "end": 0.18},
329
+ {"word": "when", "start": 0.18, "end": 0.32},
330
+ {"word": "i", "start": 0.32, "end": 0.48},
331
+ {"word": "had", "start": 0.48, "end": 0.64},
332
+ {"word": "approached", "start": 0.64, "end": 1.19},
333
+ {"word": "so", "start": 1.22, "end": 1.58},
334
+ {"word": "near", "start": 1.58, "end": 1.91},
335
+ {"word": "to", "start": 1.91, "end": 2.07},
336
+ {"word": "them", "start": 2.07, "end": 2.42},
337
+ {"word": "the", "start": 2.53, "end": 2.61},
338
+ {"word": "common", "start": 2.61, "end": 3.01},
339
+ {"word": "object", "start": 3.05, "end": 3.62},
340
+ {"word": "which", "start": 3.68, "end": 3.93},
341
+ {"word": "the", "start": 3.93, "end": 4.02},
342
+ {"word": "sense", "start": 4.02, "end": 4.34},
343
+ {"word": "deceives", "start": 4.34, "end": 4.97},
344
+ {"word": "lost", "start": 5.04, "end": 5.54},
345
+ {"word": "not", "start": 5.54, "end": 6.0},
346
+ {"word": "by", "start": 6.0, "end": 6.14},
347
+ {"word": "distance", "start": 6.14, "end": 6.67},
348
+ {"word": "any", "start": 6.79, "end": 7.05},
349
+ {"word": "of", "start": 7.05, "end": 7.18},
350
+ {"word": "its", "start": 7.18, "end": 7.34},
351
+ {"word": "marks", "start": 7.34, "end": 7.87}
352
+ ]
353
+
354
+
355
+ def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word):
356
+ if transcript not in all_demo_texts:
357
+ return transcript, edit_from_word, edit_to_word
358
+
359
+ replace_half = edit_word_mode == "Replace half"
360
+ change_edit_from_word = edit_from_word == demo_words[2] or edit_from_word == demo_words[3]
361
+ change_edit_to_word = edit_to_word == demo_words[11] or edit_to_word == demo_words[12]
362
+ demo_edit_from_word_value = demo_words[2] if replace_half else demo_words[3]
363
+ demo_edit_to_word_value = demo_words[12] if replace_half else demo_words[11]
364
+ return [
365
+ demo_text[mode]["smart" if smart_transcript else "regular"],
366
+ demo_edit_from_word_value if change_edit_from_word else edit_from_word,
367
+ demo_edit_to_word_value if change_edit_to_word else edit_to_word,
368
+ ]
369
+
370
+
371
+ with gr.Blocks() as app:
372
+ with gr.Row():
373
+ with gr.Column(scale=2):
374
+ load_models_btn = gr.Button(value="Load models")
375
+ with gr.Column(scale=5):
376
+ with gr.Accordion("Select models", open=False) as models_selector:
377
+ with gr.Row():
378
+ voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"])
379
+ whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
380
+ choices=[None, "tiny.en", "base.en", "small.en", "medium.en", "large"])
381
+
382
+ with gr.Row():
383
+ with gr.Column(scale=2):
384
+ input_audio = gr.Audio(sources=["upload", "microphone"], value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
385
+ with gr.Group():
386
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, interactive=False,
387
+ info="Use whisper model to get the transcript. Fix it if necessary.")
388
+ with gr.Accordion("Word start time", open=False):
389
+ transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
390
+ with gr.Accordion("Word end time", open=False):
391
+ transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
392
+
393
+ transcribe_btn = gr.Button(value="Transcribe")
394
+
395
+ with gr.Column(scale=3):
396
+ with gr.Group():
397
+ transcript = gr.Textbox(label="Text", lines=7, value=demo_text["TTS"]["smart"])
398
+ with gr.Row():
399
+ smart_transcript = gr.Checkbox(label="Smart transcript", value=True)
400
+ with gr.Accordion(label="?", open=False):
401
+ info = gr.Markdown(value=smart_transcript_info)
402
+
403
+ with gr.Row():
404
+ mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS")
405
+ split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline",
406
+ info="Split text into parts and run TTS for each part.", visible=False)
407
+ edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half",
408
+ info="What to do with first and last word", visible=False)
409
+
410
+ with gr.Group() as tts_mode_controls:
411
+ prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True)
412
+ prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01)
413
+
414
+ with gr.Group(visible=False) as edit_mode_controls:
415
+ with gr.Row():
416
+ edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[2], interactive=True)
417
+ edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[12], interactive=True)
418
+ with gr.Row():
419
+ edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.01, value=0.35)
420
+ edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.01, value=3.75)
421
+
422
+ run_btn = gr.Button(value="Run")
423
+
424
+ with gr.Column(scale=2):
425
+ output_audio = gr.Audio(label="Output Audio")
426
+ with gr.Accordion("Inference transcript", open=False):
427
+ inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False,
428
+ info="Inference was performed on this transcript.")
429
+ with gr.Group(visible=False) as long_tts_sentence_editor:
430
+ sentence_selector = gr.Dropdown(label="Sentence", value=None,
431
+ info="Select sentence you want to regenerate")
432
+ sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
433
+ rerun_btn = gr.Button(value="Rerun")
434
+
435
+ with gr.Row():
436
+ with gr.Accordion("VoiceCraft config", open=False):
437
+ seed = gr.Number(label="seed", value=-1, precision=0)
438
+ left_margin = gr.Number(label="left_margin", value=0.08)
439
+ right_margin = gr.Number(label="right_margin", value=0.08)
440
+ codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000)
441
+ codec_sr = gr.Number(label="codec_sr", value=50)
442
+ top_k = gr.Number(label="top_k", value=0)
443
+ top_p = gr.Number(label="top_p", value=0.8)
444
+ temperature = gr.Number(label="temperature", value=1)
445
+ stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3], value=3,
446
+ info="if there are long silence in the generated audio, reduce the stop_repetition to 3, 2 or even 1, -1 = disabled")
447
+ sample_batch_size = gr.Number(label="sample_batch_size", value=4, precision=0,
448
+ info="generate this many samples and choose the shortest one")
449
+ kvcache = gr.Radio(label="kvcache", choices=[0, 1], value=1,
450
+ info="set to 0 to use less VRAM, but with slower inference")
451
+ silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]")
452
+
453
+
454
+ audio_tensors = gr.State()
455
+ word_info = gr.State(value=demo_word_info)
456
+
457
+
458
+ mode.change(fn=update_demo,
459
+ inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
460
+ outputs=[transcript, edit_from_word, edit_to_word])
461
+ edit_word_mode.change(fn=update_demo,
462
+ inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
463
+ outputs=[transcript, edit_from_word, edit_to_word])
464
+ smart_transcript.change(fn=update_demo,
465
+ inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
466
+ outputs=[transcript, edit_from_word, edit_to_word])
467
+
468
+ load_models_btn.click(fn=load_models,
469
+ inputs=[whisper_model_choice, voicecraft_model_choice],
470
+ outputs=[models_selector])
471
+
472
+ input_audio.change(fn=update_input_audio,
473
+ inputs=[input_audio],
474
+ outputs=[prompt_end_time, edit_start_time, edit_end_time])
475
+ transcribe_btn.click(fn=transcribe,
476
+ inputs=[seed, input_audio],
477
+ outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
478
+ prompt_to_word, edit_from_word, edit_to_word, word_info])
479
+
480
+ mode.change(fn=change_mode,
481
+ inputs=[mode],
482
+ outputs=[tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor])
483
+
484
+ run_btn.click(fn=run,
485
+ inputs=[
486
+ seed, left_margin, right_margin,
487
+ codec_audio_sr, codec_sr,
488
+ top_k, top_p, temperature,
489
+ stop_repetition, sample_batch_size,
490
+ kvcache, silence_tokens,
491
+ input_audio, word_info, transcript, smart_transcript,
492
+ mode, prompt_end_time, edit_start_time, edit_end_time,
493
+ split_text, sentence_selector, audio_tensors
494
+ ],
495
+ outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors])
496
+
497
+ sentence_selector.change(fn=load_sentence,
498
+ inputs=[sentence_selector, codec_audio_sr, audio_tensors],
499
+ outputs=[sentence_audio])
500
+ rerun_btn.click(fn=run,
501
+ inputs=[
502
+ seed, left_margin, right_margin,
503
+ codec_audio_sr, codec_sr,
504
+ top_k, top_p, temperature,
505
+ stop_repetition, sample_batch_size,
506
+ kvcache, silence_tokens,
507
+ input_audio, word_info, transcript, smart_transcript,
508
+ gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time,
509
+ split_text, sentence_selector, audio_tensors
510
+ ],
511
+ outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
512
+
513
+ prompt_to_word.change(fn=update_bound_word,
514
+ inputs=[gr.State(False), prompt_to_word, gr.State("Replace all")],
515
+ outputs=[prompt_end_time])
516
+ edit_from_word.change(fn=update_bound_word,
517
+ inputs=[gr.State(True), edit_from_word, edit_word_mode],
518
+ outputs=[edit_start_time])
519
+ edit_to_word.change(fn=update_bound_word,
520
+ inputs=[gr.State(False), edit_to_word, edit_word_mode],
521
+ outputs=[edit_end_time])
522
+ edit_word_mode.change(fn=update_bound_words,
523
+ inputs=[edit_from_word, edit_to_word, edit_word_mode],
524
+ outputs=[edit_start_time, edit_end_time])
525
+
526
+
527
+ if __name__ == "__main__":
528
+ app.launch()