mrfakename committed
Commit 8474faf
1 Parent(s): 1d03890

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo rather than to this Space.

Files changed (1)
  1. inference-cli.py +33 -32
inference-cli.py CHANGED
@@ -175,6 +175,32 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
+if model == "F5-TTS":
+
+    if ckpt_file == "":
+        repo_name= "F5-TTS"
+        exp_name = "F5TTS_Base"
+        ckpt_step= 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,vocab_file)
+
+elif model == "E2-TTS":
+    if ckpt_file == "":
+        repo_name= "E2-TTS"
+        exp_name = "E2TTS_Base"
+        ckpt_step= 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,vocab_file)
+
+asr_pipe = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-large-v3-turbo",
+    torch_dtype=torch.float16,
+    device=device,
+)
+
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
@@ -206,26 +232,7 @@ def chunk_text(text, max_chars=135):
 #if not Path(ckpt_path).exists():
 #ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
-    if model == "F5-TTS":
-
-        if ckpt_file == "":
-            repo_name= "F5-TTS"
-            exp_name = "F5TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,file_vocab)
-
-    elif model == "E2-TTS":
-        if ckpt_file == "":
-            repo_name= "E2-TTS"
-            exp_name = "E2TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,file_vocab)
-
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
@@ -342,13 +349,7 @@ def process_voice(ref_audio_orig, ref_text):
 
     if not ref_text.strip():
         print("No reference text provided, transcribing reference audio...")
-        pipe = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-large-v3-turbo",
-            torch_dtype=torch.float16,
-            device=device,
-        )
-        ref_text = pipe(
+        ref_text = asr_pipe(
             ref_audio,
             chunk_length_s=30,
             batch_size=128,
@@ -360,7 +361,7 @@ def process_voice(ref_audio_orig, ref_text):
         print("Using custom reference text...")
     return ref_audio, ref_text
 
-def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
+def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
     # Add the functionality to ensure it ends with ". "
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
         if ref_text.endswith("."):
@@ -376,10 +377,10 @@ def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
 
 
-def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
+def process(ref_audio, ref_text, text_gen, model, remove_silence):
     main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -407,7 +408,7 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
         ref_audio = voices[voice]['ref_audio']
         ref_text = voices[voice]['ref_text']
         print(f"Voice: {voice}")
-        audio, spectragram = infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence)
+        audio, spectragram = infer(ref_audio, ref_text, gen_text, model,remove_silence)
        generated_audio_segments.append(audio)
 
    if generated_audio_segments:
@@ -426,4 +427,4 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
        print(f.name)
 
 
-process(ref_audio, ref_text, gen_text, model,ckpt_file,vocab_file, remove_silence)
+process(ref_audio, ref_text, gen_text, model, remove_silence)
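The change moves the one-time setup out of the per-call functions: the F5-TTS / E2-TTS checkpoint is now loaded at module scope, and the Whisper transcription pipeline that process_voice() previously rebuilt on every call is created once as asr_pipe and reused. Reduced to a standalone sketch (the transcribe() helper and its audio-path argument are illustrative and not part of the script; the pipeline arguments are the ones visible in the diff):

import torch
from transformers import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Built once at import time, as in the updated inference-cli.py,
# instead of being re-created inside process_voice() on every call.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    # the script passes torch.float16 unconditionally; fall back to
    # float32 on CPU so the sketch runs anywhere
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device=device,
)

def transcribe(ref_audio_path):
    # Same call pattern as the diff: long clips are decoded in 30 s chunks.
    result = asr_pipe(
        ref_audio_path,
        chunk_length_s=30,
        batch_size=128,
    )
    return result["text"].strip()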
 
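When no checkpoint file is supplied, the script now resolves the default checkpoint once at module level rather than inside infer_batch(). The resolution itself is a cached download from the Hugging Face Hub through cached_path's hf:// scheme, taken directly from the added block (repo, experiment name, and step are the defaults shown in the diff):

from cached_path import cached_path

# Default F5-TTS checkpoint used when ckpt_file == "":
# fetch (and locally cache) the safetensors weights from the Hub.
repo_name = "F5-TTS"
exp_name = "F5TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
print(ckpt_file)  # local path to the cached checkpoint

With the checkpoint and asr_pipe prepared up front, the call chain shrinks to process(ref_audio, ref_text, gen_text, model, remove_silence); ckpt_file and vocab_file no longer need to be threaded through infer() and infer_batch().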