Commit 197b46f · verified · 2 parents: b99d0d9 1f91df5
VX3 committed

Merge branch 'mrfakename/E2-F5-TTS' into 'VX3/MimicYouFree'

README_REPO.md CHANGED
@@ -147,11 +147,11 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
 ## Acknowledgements
 
 - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
-- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
+- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
 - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
 - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
-- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
+- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder
-- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
+- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
 - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
 - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
 - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
app.py CHANGED
@@ -1,6 +1,7 @@
 # ruff: noqa: E402
 # Above allows ruff to ignore E402: module level import not at top of file
 
+import json
 import re
 import tempfile
 from collections import OrderedDict
@@ -43,6 +44,12 @@ from f5_tts.infer.utils_infer import (
 DEFAULT_TTS_MODEL = "F5-TTS"
 tts_model_choice = DEFAULT_TTS_MODEL
 
+DEFAULT_TTS_MODEL_CFG = [
+    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
+    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
+    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
+]
+
 
 # load models
 
@@ -103,8 +110,24 @@ def generate_response(messages, model, tokenizer):
 
 @gpu_decorator
 def infer(
-    ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1, show_info=gr.Info
+    ref_audio_orig,
+    ref_text,
+    gen_text,
+    model,
+    remove_silence,
+    cross_fade_duration=0.15,
+    nfe_step=32,
+    speed=1,
+    show_info=gr.Info,
 ):
+    if not ref_audio_orig:
+        gr.Warning("Please provide reference audio.")
+        return gr.update(), gr.update(), ref_text
+
+    if not gen_text.strip():
+        gr.Warning("Please enter text to generate.")
+        return gr.update(), gr.update(), ref_text
+
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
 
     if model == "F5-TTS":
@@ -120,7 +143,7 @@ def infer(
         global custom_ema_model, pre_custom_path
         if pre_custom_path != model[1]:
             show_info("Loading Custom TTS model...")
-            custom_ema_model = load_custom(model[1], vocab_path=model[2])
+            custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3])
             pre_custom_path = model[1]
         ema_model = custom_ema_model
 
@@ -131,6 +154,7 @@ def infer(
         ema_model,
         vocoder,
         cross_fade_duration=cross_fade_duration,
+        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
@@ -184,6 +208,14 @@ with gr.Blocks() as app_tts:
            step=0.1,
            info="Adjust the speed of the audio.",
        )
+        nfe_slider = gr.Slider(
+            label="NFE Steps",
+            minimum=4,
+            maximum=64,
+            value=32,
+            step=2,
+            info="Set the number of denoising steps.",
+        )
        cross_fade_duration_slider = gr.Slider(
            label="Cross-Fade Duration (s)",
            minimum=0.0,
@@ -203,6 +235,7 @@ with gr.Blocks() as app_tts:
        gen_text_input,
        remove_silence,
        cross_fade_duration_slider,
+        nfe_slider,
        speed_slider,
    ):
        audio_out, spectrogram_path, ref_text_out = infer(
@@ -211,10 +244,11 @@ with gr.Blocks() as app_tts:
            gen_text_input,
            tts_model_choice,
            remove_silence,
-            cross_fade_duration_slider,
-            speed_slider,
+            cross_fade_duration=cross_fade_duration_slider,
+            nfe_step=nfe_slider,
+            speed=speed_slider,
        )
-        return audio_out, spectrogram_path, gr.update(value=ref_text_out)
+        return audio_out, spectrogram_path, ref_text_out
 
    generate_btn.click(
        basic_tts,
@@ -224,6 +258,7 @@ with gr.Blocks() as app_tts:
            gen_text_input,
            remove_silence,
            cross_fade_duration_slider,
+            nfe_slider,
            speed_slider,
        ],
        outputs=[audio_output, spectrogram_output, ref_text_input],
@@ -293,7 +328,7 @@ with gr.Blocks() as app_multistyle:
    )
 
    # Regular speech type (mandatory)
-    with gr.Row():
+    with gr.Row() as regular_row:
        with gr.Column():
            regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
            regular_insert = gr.Button("Insert Label", variant="secondary")
@@ -302,12 +337,12 @@ with gr.Blocks() as app_multistyle:
 
    # Regular speech type (max 100)
    max_speech_types = 100
-    speech_type_rows = []  # 99
-    speech_type_names = [regular_name]  # 100
-    speech_type_audios = [regular_audio]  # 100
-    speech_type_ref_texts = [regular_ref_text]  # 100
-    speech_type_delete_btns = []  # 99
-    speech_type_insert_btns = [regular_insert]  # 100
+    speech_type_rows = [regular_row]
+    speech_type_names = [regular_name]
+    speech_type_audios = [regular_audio]
+    speech_type_ref_texts = [regular_ref_text]
+    speech_type_delete_btns = [None]
+    speech_type_insert_btns = [regular_insert]
 
    # Additional speech types (99 more)
    for i in range(max_speech_types - 1):
@@ -328,51 +363,32 @@ with gr.Blocks() as app_multistyle:
    # Button to add speech type
    add_speech_type_btn = gr.Button("Add Speech Type")
 
-    # Keep track of current number of speech types
-    speech_type_count = gr.State(value=1)
+    # Keep track of autoincrement of speech types, no roll back
+    speech_type_count = 1
 
    # Function to add a speech type
-    def add_speech_type_fn(speech_type_count):
+    def add_speech_type_fn():
+        row_updates = [gr.update() for _ in range(max_speech_types)]
+        global speech_type_count
        if speech_type_count < max_speech_types:
+            row_updates[speech_type_count] = gr.update(visible=True)
            speech_type_count += 1
-            # Prepare updates for the rows
-            row_updates = []
-            for i in range(1, max_speech_types):
-                if i < speech_type_count:
-                    row_updates.append(gr.update(visible=True))
-                else:
-                    row_updates.append(gr.update())
        else:
-            # Optionally, show a warning
-            row_updates = [gr.update() for _ in range(1, max_speech_types)]
-        return [speech_type_count] + row_updates
+            gr.Warning("Exhausted maximum number of speech types. Consider restart the app.")
+        return row_updates
 
-    add_speech_type_btn.click(
-        add_speech_type_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows
-    )
+    add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)
 
    # Function to delete a speech type
-    def make_delete_speech_type_fn(index):
-        def delete_speech_type_fn(speech_type_count):
-            # Prepare updates
-            row_updates = []
-
-            for i in range(1, max_speech_types):
-                if i == index:
-                    row_updates.append(gr.update(visible=False))
-                else:
-                    row_updates.append(gr.update())
-
-            speech_type_count = max(1, speech_type_count)
-
-            return [speech_type_count] + row_updates
-
-        return delete_speech_type_fn
+    def delete_speech_type_fn():
+        return gr.update(visible=False), None, None, None
 
    # Update delete button clicks
-    for i, delete_btn in enumerate(speech_type_delete_btns):
-        delete_fn = make_delete_speech_type_fn(i)
-        delete_btn.click(delete_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows)
+    for i in range(1, len(speech_type_delete_btns)):
+        speech_type_delete_btns[i].click(
+            delete_speech_type_fn,
+            outputs=[speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i]],
+        )
 
    # Text input for the prompt
    gen_text_input_multistyle = gr.Textbox(
@@ -386,7 +402,7 @@ with gr.Blocks() as app_multistyle:
            current_text = current_text or ""
            speech_type_name = speech_type_name or "None"
            updated_text = current_text + f"{{{speech_type_name}}} "
-            return gr.update(value=updated_text)
+            return updated_text
 
        return insert_speech_type_fn
 
@@ -446,10 +462,14 @@ with gr.Blocks() as app_multistyle:
            if style in speech_types:
                current_style = style
            else:
-                # If style not available, default to Regular
+                gr.Warning(f"Type {style} is not available, will use Regular as default.")
                current_style = "Regular"
 
-            ref_audio = speech_types[current_style]["audio"]
+            try:
+                ref_audio = speech_types[current_style]["audio"]
+            except KeyError:
+                gr.Warning(f"Please provide reference audio for type {current_style}.")
+                return [None] + [speech_types[style]["ref_text"] for style in speech_types]
            ref_text = speech_types[current_style].get("ref_text", "")
 
            # Generate speech for this segment
@@ -464,12 +484,10 @@ with gr.Blocks() as app_multistyle:
        # Concatenate all audio segments
        if generated_audio_segments:
            final_audio_data = np.concatenate(generated_audio_segments)
-            return [(sr, final_audio_data)] + [
-                gr.update(value=speech_types[style]["ref_text"]) for style in speech_types
-            ]
+            return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
        else:
            gr.Warning("No audio generated.")
-            return [None] + [gr.update(value=speech_types[style]["ref_text"]) for style in speech_types]
+            return [None] + [speech_types[style]["ref_text"] for style in speech_types]
 
    generate_multistyle_btn.click(
        generate_multistyle_speech,
@@ -487,7 +505,7 @@ with gr.Blocks() as app_multistyle:
 
    # Validation function to disable Generate button if speech types are missing
    def validate_speech_types(gen_text, regular_name, *args):
-        speech_type_names_list = args[:max_speech_types]
+        speech_type_names_list = args
 
        # Collect the speech types names
        speech_types_available = set()
@@ -651,7 +669,7 @@ Have a conversation with an AI using your reference voice!
            speed=1.0,
            show_info=print,  # show_info=print no pull to top when generating
        )
-        return audio_result, gr.update(value=ref_text_out)
+        return audio_result, ref_text_out
 
    def clear_conversation():
        """Reset the conversation"""
@@ -744,34 +762,38 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
    """
    )
 
-    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom.txt")
+    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
 
    def load_last_used_custom():
        try:
-            with open(last_used_custom, "r") as f:
-                return f.read().split(",")
+            custom = []
+            with open(last_used_custom, "r", encoding="utf-8") as f:
+                for line in f:
+                    custom.append(line.strip())
+            return custom
        except FileNotFoundError:
            last_used_custom.parent.mkdir(parents=True, exist_ok=True)
-            return [
-                "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
-                "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
-            ]
+            return DEFAULT_TTS_MODEL_CFG
 
    def switch_tts_model(new_choice):
        global tts_model_choice
        if new_choice == "Custom":  # override in case webpage is refreshed
-            custom_ckpt_path, custom_vocab_path = load_last_used_custom()
-            tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
-            return gr.update(visible=True, value=custom_ckpt_path), gr.update(visible=True, value=custom_vocab_path)
+            custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
+            tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
+            return (
+                gr.update(visible=True, value=custom_ckpt_path),
+                gr.update(visible=True, value=custom_vocab_path),
+                gr.update(visible=True, value=custom_model_cfg),
+            )
        else:
            tts_model_choice = new_choice
-            return gr.update(visible=False), gr.update(visible=False)
+            return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
 
-    def set_custom_model(custom_ckpt_path, custom_vocab_path):
+    def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
        global tts_model_choice
-        tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
-        with open(last_used_custom, "w") as f:
-            f.write(f"{custom_ckpt_path},{custom_vocab_path}")
+        tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
+        with open(last_used_custom, "w", encoding="utf-8") as f:
+            f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
 
    with gr.Row():
        if not USING_SPACES:
@@ -783,34 +805,49 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
                choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
            )
        custom_ckpt_path = gr.Dropdown(
-            choices=["hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"],
+            choices=[DEFAULT_TTS_MODEL_CFG[0]],
            value=load_last_used_custom()[0],
            allow_custom_value=True,
-            label="MODEL CKPT: local_path | hf://user_id/repo_id/model_ckpt",
+            label="Model: local_path | hf://user_id/repo_id/model_ckpt",
            visible=False,
        )
        custom_vocab_path = gr.Dropdown(
-            choices=["hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"],
+            choices=[DEFAULT_TTS_MODEL_CFG[1]],
            value=load_last_used_custom()[1],
            allow_custom_value=True,
-            label="VOCAB FILE: local_path | hf://user_id/repo_id/vocab_file",
+            label="Vocab: local_path | hf://user_id/repo_id/vocab_file",
+            visible=False,
+        )
+        custom_model_cfg = gr.Dropdown(
+            choices=[
+                DEFAULT_TTS_MODEL_CFG[2],
+                json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
+            ],
+            value=load_last_used_custom()[2],
+            allow_custom_value=True,
+            label="Config: in a dictionary form",
            visible=False,
        )
 
        choose_tts_model.change(
            switch_tts_model,
            inputs=[choose_tts_model],
-            outputs=[custom_ckpt_path, custom_vocab_path],
+            outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
            show_progress="hidden",
        )
        custom_ckpt_path.change(
            set_custom_model,
-            inputs=[custom_ckpt_path, custom_vocab_path],
+            inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
            show_progress="hidden",
        )
        custom_vocab_path.change(
            set_custom_model,
-            inputs=[custom_ckpt_path, custom_vocab_path],
+            inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
+            show_progress="hidden",
+        )
+        custom_model_cfg.change(
+            set_custom_model,
+            inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
            show_progress="hidden",
        )
 
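For context, a minimal sketch of how the reworked custom-model cache and config are consumed, mirroring `load_last_used_custom` / `set_custom_model` above. The cache path shown here is only a shortened placeholder (app.py resolves it via `files("f5_tts")`); the one-value-per-line format and the `tts_model_choice` shape follow this commit.

```python
import json
from pathlib import Path

# Placeholder path for illustration; the app uses infer/.cache/last_used_custom_model_info.txt.
cache = Path(".cache/last_used_custom_model_info.txt")

# One value per line: checkpoint path, vocab path, JSON model config.
ckpt, vocab, cfg_str = [line.strip() for line in cache.read_text(encoding="utf-8").splitlines()]

# The app stores the parsed config dict as the fourth element of tts_model_choice.
tts_model_choice = ["Custom", ckpt, vocab, json.loads(cfg_str)]
```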
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.1.2"
+version = "0.3.1"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
@@ -21,6 +21,7 @@ dependencies = [
     "datasets",
     "ema_pytorch>=0.5.2",
     "gradio>=3.45.2",
+    "hydra-core>=1.3.0",
     "jieba",
     "librosa",
     "matplotlib",
src/f5_tts/configs/E2TTS_Base_train.yaml ADDED
@@ -0,0 +1,44 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN  # dataset name
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0  # gradient clipping
+  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: E2TTS_Base
+  tokenizer: pinyin
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024
+    depth: 24
+    heads: 16
+    ff_mult: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/E2TTS_Small_train.yaml ADDED
@@ -0,0 +1,44 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0
+  bnb_optimizer: False
+
+model:
+  name: E2TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 20
+    heads: 12
+    ff_mult: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Base_train.yaml ADDED
@@ -0,0 +1,47 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN  # dataset name
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0  # gradient clipping
+  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Base  # model name
+  tokenizer: pinyin  # tokenizer type
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024
+    depth: 22
+    heads: 16
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Small_train.yaml ADDED
@@ -0,0 +1,47 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0  # gradient clipping
+  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 18
+    heads: 12
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
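These training configs are plain Hydra/OmegaConf YAML (hence the new `hydra-core` dependency in pyproject.toml). A minimal sketch of inspecting one outside the training entry point; the config path matches the files added in this commit, while the assumption is that the actual trainer consumes them through Hydra's own machinery (e.g. `@hydra.main`):

```python
from omegaconf import OmegaConf  # installed alongside hydra-core

# Load one of the new training configs directly for inspection.
cfg = OmegaConf.load("src/f5_tts/configs/F5TTS_Base_train.yaml")

print(cfg.model.name)       # F5TTS_Base
print(cfg.model.arch.dim)   # 1024
print(cfg.datasets.name)    # Emilia_ZH_EN
```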
src/f5_tts/eval/README.md CHANGED
@@ -39,11 +39,14 @@ Then update in the following scripts with the paths you put evaluation model ckp
 
 ### Objective Evaluation
 
-Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
+Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
 ```bash
-# Evaluation for Seed-TTS test set
-python src/f5_tts/eval/eval_seedtts_testset.py
+# Evaluation [WER] for Seed-TTS test [ZH] set
+python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8
 
-# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
-python src/f5_tts/eval/eval_librispeech_test_clean.py
-```
+# Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
+python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
+
+# Evaluation [UTMOS]. --ext: Audio extension
+python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
+```
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -34,8 +34,6 @@ win_length = 1024
 n_fft = 1024
 target_rms = 0.1
 
-
-tokenizer = "pinyin"
 rel_path = str(files("f5_tts").joinpath("../../"))
 
 
@@ -49,6 +47,7 @@ def main():
     parser.add_argument("-n", "--expname", required=True)
     parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
     parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
+    parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
 
     parser.add_argument("-nfe", "--nfestep", default=32, type=int)
     parser.add_argument("-o", "--odemethod", default="euler")
@@ -64,6 +63,7 @@ def main():
     ckpt_step = args.ckptstep
     ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
     mel_spec_type = args.mel_spec_type
+    tokenizer = args.tokenizer
 
     nfe_step = args.nfestep
     ode_method = args.odemethod
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -1,7 +1,9 @@
 # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
 
-import sys
+import argparse
+import json
 import os
+import sys
 
 sys.path.append(os.getcwd())
 
@@ -9,7 +11,6 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-
 from f5_tts.eval.utils_eval import (
     get_librispeech_test,
     run_asr_wer,
@@ -19,55 +20,77 @@ from f5_tts.eval.utils_eval import (
 rel_path = str(files("f5_tts").joinpath("../../"))
 
 
-eval_task = "wer"  # sim | wer
-lang = "en"
-metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
-librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
-gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs
-
-gpus = [0, 1, 2, 3, 4, 5, 6, 7]
-test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
-
-## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
-## leading to a low similarity for the ground truth in some cases.
-# test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth
-
-local = False
-if local:  # use local custom checkpoint dir
-    asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
-else:
-    asr_ckpt_dir = ""  # auto download to cache dir
-
-wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
-
-
-# --------------------------- WER ---------------------------
-
-if eval_task == "wer":
-    wers = []
-
-    with mp.Pool(processes=len(gpus)) as pool:
-        args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
-        results = pool.map(run_asr_wer, args)
-        for wers_ in results:
-            wers.extend(wers_)
-
-    wer = round(np.mean(wers) * 100, 3)
-    print(f"\nTotal {len(wers)} samples")
-    print(f"WER : {wer}%")
-
-
-# --------------------------- SIM ---------------------------
-
-if eval_task == "sim":
-    sim_list = []
-
-    with mp.Pool(processes=len(gpus)) as pool:
-        args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
-        results = pool.map(run_sim, args)
-        for sim_ in results:
-            sim_list.extend(sim_)
-
-    sim = round(sum(sim_list) / len(sim_list), 3)
-    print(f"\nTotal {len(sim_list)} samples")
-    print(f"SIM : {sim}")
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
+    parser.add_argument("-l", "--lang", type=str, default="en")
+    parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
+    parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
+    parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
+    parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    eval_task = args.eval_task
+    lang = args.lang
+    librispeech_test_clean_path = args.librispeech_test_clean_path  # test-clean path
+    gen_wav_dir = args.gen_wav_dir
+    metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
+
+    gpus = list(range(args.gpu_nums))
+    test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
+
+    ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
+    ## leading to a low similarity for the ground truth in some cases.
+    # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth
+
+    local = args.local
+    if local:  # use local custom checkpoint dir
+        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
+    else:
+        asr_ckpt_dir = ""  # auto download to cache dir
+    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
+
+    # --------------------------- WER ---------------------------
+
+    if eval_task == "wer":
+        wer_results = []
+        wers = []
+
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_asr_wer, args)
+            for r in results:
+                wer_results.extend(r)
+
+        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
+        with open(wer_result_path, "w") as f:
+            for line in wer_results:
+                wers.append(line["wer"])
+                json_line = json.dumps(line, ensure_ascii=False)
+                f.write(json_line + "\n")
+
+        wer = round(np.mean(wers) * 100, 3)
+        print(f"\nTotal {len(wers)} samples")
+        print(f"WER : {wer}%")
+        print(f"Results have been saved to {wer_result_path}")
+
+    # --------------------------- SIM ---------------------------
+
+    if eval_task == "sim":
+        sims = []
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_sim, args)
+            for r in results:
+                sims.extend(r)
+
+        sim = round(sum(sims) / len(sims), 3)
+        print(f"\nTotal {len(sims)} samples")
+        print(f"SIM : {sim}")
+
+
+if __name__ == "__main__":
+    main()
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -1,7 +1,9 @@
 # Evaluate with Seed-TTS testset
 
-import sys
+import argparse
+import json
 import os
+import sys
 
 sys.path.append(os.getcwd())
 
@@ -9,7 +11,6 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-
 from f5_tts.eval.utils_eval import (
     get_seed_tts_test,
     run_asr_wer,
@@ -19,57 +20,76 @@ from f5_tts.eval.utils_eval import (
 rel_path = str(files("f5_tts").joinpath("../../"))
 
 
-eval_task = "wer"  # sim | wer
-lang = "zh"  # zh | en
-metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
-# gen_wav_dir = rel_path + f"/data/seedtts_testset/{lang}/wavs"  # ground truth wavs
-gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs
-
-
-# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
-# zh 1.254 seems a result of 4 workers wer_seed_tts
-gpus = [0, 1, 2, 3, 4, 5, 6, 7]
-test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
-
-local = False
-if local:  # use local custom checkpoint dir
-    if lang == "zh":
-        asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
-    elif lang == "en":
-        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
-else:
-    asr_ckpt_dir = ""  # auto download to cache dir
-
-wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
-
-
-# --------------------------- WER ---------------------------
-
-if eval_task == "wer":
-    wers = []
-
-    with mp.Pool(processes=len(gpus)) as pool:
-        args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
-        results = pool.map(run_asr_wer, args)
-        for wers_ in results:
-            wers.extend(wers_)
-
-    wer = round(np.mean(wers) * 100, 3)
-    print(f"\nTotal {len(wers)} samples")
-    print(f"WER : {wer}%")
-
-
-# --------------------------- SIM ---------------------------
-
-if eval_task == "sim":
-    sim_list = []
-
-    with mp.Pool(processes=len(gpus)) as pool:
-        args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
-        results = pool.map(run_sim, args)
-        for sim_ in results:
-            sim_list.extend(sim_)
-
-    sim = round(sum(sim_list) / len(sim_list), 3)
-    print(f"\nTotal {len(sim_list)} samples")
-    print(f"SIM : {sim}")
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
+    parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
+    parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
+    parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
+    parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    eval_task = args.eval_task
+    lang = args.lang
+    gen_wav_dir = args.gen_wav_dir
+    metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
+
+    # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
+    # zh 1.254 seems a result of 4 workers wer_seed_tts
+    gpus = list(range(args.gpu_nums))
+    test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
+
+    local = args.local
+    if local:  # use local custom checkpoint dir
+        if lang == "zh":
+            asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
+        elif lang == "en":
+            asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
+    else:
+        asr_ckpt_dir = ""  # auto download to cache dir
+    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
+
+    # --------------------------- WER ---------------------------
+
+    if eval_task == "wer":
+        wer_results = []
+        wers = []
+
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_asr_wer, args)
+            for r in results:
+                wer_results.extend(r)
+
+        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
+        with open(wer_result_path, "w") as f:
+            for line in wer_results:
+                wers.append(line["wer"])
+                json_line = json.dumps(line, ensure_ascii=False)
+                f.write(json_line + "\n")
+
+        wer = round(np.mean(wers) * 100, 3)
+        print(f"\nTotal {len(wers)} samples")
+        print(f"WER : {wer}%")
+        print(f"Results have been saved to {wer_result_path}")
+
+    # --------------------------- SIM ---------------------------
+
+    if eval_task == "sim":
+        sims = []
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_sim, args)
+            for r in results:
+                sims.extend(r)
+
+        sim = round(sum(sims) / len(sims), 3)
+        print(f"\nTotal {len(sims)} samples")
+        print(f"SIM : {sim}")
+
+
+if __name__ == "__main__":
+    main()
src/f5_tts/eval/eval_utmos.py ADDED
@@ -0,0 +1,44 @@
+import argparse
+import json
+from pathlib import Path
+
+import librosa
+import torch
+from tqdm import tqdm
+
+
+def main():
+    parser = argparse.ArgumentParser(description="UTMOS Evaluation")
+    parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.")
+    parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
+    args = parser.parse_args()
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
+    predictor = predictor.to(device)
+
+    audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
+    utmos_results = {}
+    utmos_score = 0
+
+    for audio_path in tqdm(audio_paths, desc="Processing"):
+        wav_name = audio_path.stem
+        wav, sr = librosa.load(audio_path, sr=None, mono=True)
+        wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+        score = predictor(wav_tensor, sr)
+        utmos_results[str(wav_name)] = score.item()
+        utmos_score += score.item()
+
+    avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
+    print(f"UTMOS: {avg_score}")
+
+    utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
+    with open(utmos_result_path, "w", encoding="utf-8") as f:
+        json.dump(utmos_results, f, ensure_ascii=False, indent=4)
+
+    print(f"Results have been saved to {utmos_result_path}")
+
+
+if __name__ == "__main__":
+    main()
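A small sketch of consuming the `utmos_results.json` that the script above writes next to the audio directory (per-file scores keyed by wav stem); the `<WAV_DIR>` placeholder is the same one the eval README uses:

```python
import json
from pathlib import Path

results_path = Path("<WAV_DIR>") / "utmos_results.json"  # written by eval_utmos.py
scores = json.loads(results_path.read_text(encoding="utf-8"))

# Recompute the corpus-level UTMOS and list the lowest-scoring clips.
print(sum(scores.values()) / len(scores))
for wav, score in sorted(scores.items(), key=lambda kv: kv[1])[:10]:
    print(f"{score:.3f}  {wav}")
```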
src/f5_tts/eval/utils_eval.py CHANGED
@@ -2,6 +2,7 @@ import math
 import os
 import random
 import string
+from pathlib import Path
 
 import torch
 import torch.nn.functional as F
@@ -320,7 +321,7 @@ def run_asr_wer(args):
     from zhon.hanzi import punctuation
 
     punctuation_all = punctuation + string.punctuation
-    wers = []
+    wer_results = []
 
     from jiwer import compute_measures
 
@@ -335,8 +336,8 @@ def run_asr_wer(args):
         for segment in segments:
             hypo = hypo + " " + segment.text
 
-        # raw_truth = truth
-        # raw_hypo = hypo
+        raw_truth = truth
+        raw_hypo = hypo
 
         for x in punctuation_all:
             truth = truth.replace(x, "")
@@ -360,9 +361,16 @@ def run_asr_wer(args):
         # dele = measures["deletions"] / len(ref_list)
         # inse = measures["insertions"] / len(ref_list)
 
-        wers.append(wer)
+        wer_results.append(
+            {
+                "wav": Path(gen_wav).stem,
+                "truth": raw_truth,
+                "hypo": raw_hypo,
+                "wer": wer,
+            }
+        )
 
-    return wers
+    return wer_results
 
 
 # SIM Evaluation
@@ -381,7 +389,7 @@ def run_sim(args):
     model = model.cuda(device)
     model.eval()
 
-    sim_list = []
+    sims = []
     for wav1, wav2, truth in tqdm(test_set):
         wav1, sr1 = torchaudio.load(wav1)
         wav2, sr2 = torchaudio.load(wav2)
@@ -400,6 +408,6 @@ def run_sim(args):
 
         sim = F.cosine_similarity(emb1, emb2)[0].item()
         # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
-        sim_list.append(sim)
+        sims.append(sim)
 
-    return sim_list
+    return sims
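Since `run_asr_wer` now returns per-utterance dicts (`wav` / `truth` / `hypo` / `wer`) that the eval scripts dump to `{lang}_wer_results.jsonl`, here is a quick sketch for inspecting the worst transcriptions from that file; the `<GEN_WAV_DIR>` path and the `zh` language tag are placeholders:

```python
import json

# One JSON object per line, as written by eval_seedtts_testset.py / eval_librispeech_test_clean.py.
with open("<GEN_WAV_DIR>/zh_wer_results.jsonl", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

# Show the five utterances with the highest WER, with reference and hypothesis text.
for r in sorted(records, key=lambda r: r["wer"], reverse=True)[:5]:
    print(f'{r["wer"]:.3f}  {r["wav"]}')
    print(f'  truth: {r["truth"]}')
    print(f'  hypo : {r["hypo"]}')
```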
src/f5_tts/infer/README.md CHANGED
@@ -12,6 +12,8 @@ To avoid possible inference failures, make sure you have seen through the follow
 - Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
 - Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses.
 - Preprocess numbers to Chinese letters if you want to have them read in Chinese, otherwise in English.
+- If the generation output is blank (pure silence), check for ffmpeg installation (various tutorials online, blogs, videos, etc.).
+- Try turn off use_ema if using an early-stage finetuned checkpoint (which goes just few updates).
 
 
 ## Gradio App
@@ -62,6 +64,9 @@ f5-tts_infer-cli \
 # Choose Vocoder
 f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
 f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
+
+# More instructions
+f5-tts_infer-cli --help
 ```
 
 And a `.toml` file would help with more flexible usage.
src/f5_tts/infer/SHARED.md CHANGED
@@ -16,59 +16,131 @@
16
  <!-- omit in toc -->
17
  ### Supported Languages
18
  - [Multilingual](#multilingual)
19
- - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
20
- - [Mandarin](#mandarin)
21
- - [Japanese](#japanese)
22
- - [F5-TTS Base @ pretrain/finetune @ ja](#f5-tts-base--pretrainfinetune--ja)
23
  - [English](#english)
 
 
24
  - [French](#french)
25
- - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  ## Multilingual
29
 
30
- #### F5-TTS Base @ pretrain @ zh & en
31
  |Model|🤗Hugging Face|Data (Hours)|Model License|
32
  |:---:|:------------:|:-----------:|:-------------:|
33
  |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
34
 
35
  ```bash
36
- MODEL_CKPT: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
37
- VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
 
38
  ```
39
 
40
  *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
41
 
42
 
43
- ## Mandarin
44
 
45
- ## Japanese
46
 
47
- #### F5-TTS Base @ pretrain/finetune @ ja
48
- |Model|🤗Hugging Face|Data (Hours)|Model License|
 
 
49
  |:---:|:------------:|:-----------:|:-------------:|
50
- |F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
51
 
52
  ```bash
53
- MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
54
- VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
 
55
  ```
56
 
57
- ## English
58
-
59
 
60
  ## French
61
 
62
- #### French LibriVox @ finetune @ fr
63
  |Model|🤗Hugging Face|Data (Hours)|Model License|
64
  |:---:|:------------:|:-----------:|:-------------:|
65
- |F5-TTS French|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
66
 
67
  ```bash
68
- MODEL_CKPT: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
69
- VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
 
70
  ```
71
 
72
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
73
  - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
74
  - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  <!-- omit in toc -->
17
  ### Supported Languages
18
  - [Multilingual](#multilingual)
19
+ - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
 
 
 
20
  - [English](#english)
21
+ - [Finnish](#finnish)
22
+ - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
23
  - [French](#french)
24
+ - [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
25
+ - [Hindi](#hindi)
26
+ - [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
27
+ - [Italian](#italian)
28
+ - [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79)
29
+ - [Japanese](#japanese)
30
+ - [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica)
31
+ - [Mandarin](#mandarin)
32
+ - [Spanish](#spanish)
33
+ - [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar)
34
 
35
 
36
  ## Multilingual
37
 
38
+ #### F5-TTS Base @ zh & en @ F5-TTS
39
  |Model|🤗Hugging Face|Data (Hours)|Model License|
40
  |:---:|:------------:|:-----------:|:-------------:|
41
  |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
42
 
43
  ```bash
44
+ Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
45
+ Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
46
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
47
  ```
48
 
49
  *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
50
 
51
 
52
+ ## English
53
 
 
54
 
55
+ ## Finnish
56
+
57
+ #### F5-TTS Base @ fi @ AsmoKoskinen
58
+ |Model|🤗Hugging Face|Data|Model License|
59
  |:---:|:------------:|:-----------:|:-------------:|
60
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|
61
 
62
  ```bash
63
+ Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
64
+ Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
65
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
66
  ```
67
 
 
 
68
 
69
  ## French
70
 
71
+ #### F5-TTS Base @ fr @ RASPIAUDIO
72
  |Model|🤗Hugging Face|Data (Hours)|Model License|
73
  |:---:|:------------:|:-----------:|:-------------:|
74
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
75
 
76
  ```bash
77
+ Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
78
+ Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
79
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
80
  ```
81
 
82
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
83
  - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
84
  - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
85
+
86
+
87
+ ## Hindi
88
+
89
+ #### F5-TTS Small @ hi @ SPRINGLab
90
+ |Model|🤗Hugging Face|Data (Hours)|Model License|
91
+ |:---:|:------------:|:-----------:|:-------------:|
92
+ |F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
93
+
94
+ ```bash
95
+ Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
96
+ Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
97
+ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
98
+ ```
99
+
100
+ - Authors: SPRING Lab, Indian Institute of Technology, Madras
101
+ - Website: https://asr.iitm.ac.in/
102
+
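
Note that this card's `Config` describes the smaller `F5-TTS Small` architecture; as a rough sketch, the JSON maps one-to-one onto the `DiT` keyword arguments used by this repo's `train.py` / `infer_cli.py` (`vocab_size` and `mel_dim=100` below are assumptions based on the vocab file above and the default mel settings):

```python
# Sketch only: the Config JSON above corresponds to the DiT kwargs used elsewhere
# in this repo; mel_dim=100 mirrors the default n_mel_channels setting.
from f5_tts.model import DiT

model_cfg = dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)
# model = DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=100)  # vocab_size from vocab.txt
```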
103
+
104
+ ## Italian
105
+
106
+ #### F5-TTS Base @ it @ alien79
107
+ |Model|🤗Hugging Face|Data|Model License|
108
+ |:---:|:------------:|:-----------:|:-------------:|
109
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
110
+
111
+ ```bash
112
+ Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
113
+ Vocab: hf://alien79/F5-TTS-italian/vocab.txt
114
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
115
+ ```
116
+
117
+ - Trained by [Mithril Man](https://github.com/MithrilMan)
118
+ - Model details on [hf project home](https://huggingface.co/alien79/F5-TTS-italian)
119
+ - Open to collaborations to further improve the model
120
+
121
+
122
+ ## Japanese
123
+
124
+ #### F5-TTS Base @ ja @ Jmica
125
+ |Model|🤗Hugging Face|Data (Hours)|Model License|
126
+ |:---:|:------------:|:-----------:|:-------------:|
127
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
128
+
129
+ ```bash
130
+ Model: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
131
+ Vocab: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
132
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
133
+ ```
134
+
135
+
136
+ ## Mandarin
137
+
138
+
139
+ ## Spanish
140
+
141
+ #### F5-TTS Base @ es @ jpgallegoar
142
+ |Model|🤗Hugging Face|Data (Hours)|Model License|
143
+ |:---:|:------------:|:-----------:|:-------------:|
144
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
145
+
146
+ - @jpgallegoar's [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), with a Jupyter Notebook and Gradio usage for the Spanish model.
src/f5_tts/infer/examples/basic/basic.toml CHANGED
@@ -8,4 +8,4 @@ gen_text = "I don't really care what you call me. I've been a silent spectator,
8
  gen_file = ""
9
  remove_silence = false
10
  output_dir = "tests"
11
- output_file = "infer_cli_out.wav"
 
8
  gen_file = ""
9
  remove_silence = false
10
  output_dir = "tests"
11
+ output_file = "infer_cli_basic.wav"
src/f5_tts/infer/examples/multi/story.toml CHANGED
@@ -8,6 +8,7 @@ gen_text = ""
8
  gen_file = "infer/examples/multi/story.txt"
9
  remove_silence = true
10
  output_dir = "tests"
 
11
 
12
  [voices.town]
13
  ref_audio = "infer/examples/multi/town.flac"
 
8
  gen_file = "infer/examples/multi/story.txt"
9
  remove_silence = true
10
  output_dir = "tests"
11
+ output_file = "infer_cli_story.wav"
12
 
13
  [voices.town]
14
  ref_audio = "infer/examples/multi/town.flac"
src/f5_tts/infer/infer_cli.py CHANGED
@@ -2,6 +2,7 @@ import argparse
2
  import codecs
3
  import os
4
  import re
 
5
  from importlib.resources import files
6
  from pathlib import Path
7
 
@@ -9,8 +10,17 @@ import numpy as np
9
  import soundfile as sf
10
  import tomli
11
  from cached_path import cached_path
 
12
 
13
  from f5_tts.infer.utils_infer import (
 
 
 
 
 
 
 
 
14
  infer_process,
15
  load_model,
16
  load_vocoder,
@@ -19,6 +29,7 @@ from f5_tts.infer.utils_infer import (
19
  )
20
  from f5_tts.model import DiT, UNetT
21
 
 
22
  parser = argparse.ArgumentParser(
23
  prog="python3 infer-cli.py",
24
  description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
@@ -27,74 +38,168 @@ parser = argparse.ArgumentParser(
27
  parser.add_argument(
28
  "-c",
29
  "--config",
30
- help="Configuration file. Default=infer/examples/basic/basic.toml",
31
  default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
 
32
  )
 
 
 
 
33
  parser.add_argument(
34
  "-m",
35
  "--model",
36
- help="F5-TTS | E2-TTS",
 
 
 
 
 
 
 
37
  )
38
  parser.add_argument(
39
  "-p",
40
  "--ckpt_file",
41
- help="The Checkpoint .pt",
 
42
  )
43
  parser.add_argument(
44
  "-v",
45
  "--vocab_file",
46
- help="The vocab .txt",
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  )
48
- parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
49
- parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
50
  parser.add_argument(
51
  "-t",
52
  "--gen_text",
53
  type=str,
54
- help="Text to generate.",
55
  )
56
  parser.add_argument(
57
  "-f",
58
  "--gen_file",
59
  type=str,
60
- help="File with text to generate. Ignores --gen_text",
61
  )
62
  parser.add_argument(
63
  "-o",
64
  "--output_dir",
65
  type=str,
66
- help="Path to output folder..",
67
  )
68
  parser.add_argument(
69
  "-w",
70
  "--output_file",
71
  type=str,
72
- help="Filename of output file..",
 
 
 
 
 
73
  )
74
  parser.add_argument(
75
  "--remove_silence",
76
- help="Remove silence.",
 
77
  )
78
- parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
79
  parser.add_argument(
80
  "--load_vocoder_from_local",
81
  action="store_true",
82
- help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
  parser.add_argument(
85
  "--speed",
86
  type=float,
87
- default=1.0,
88
- help="Adjust the speed of the audio generation (default: 1.0)",
 
 
 
 
89
  )
90
  args = parser.parse_args()
91
 
 
 
 
92
  config = tomli.load(open(args.config, "rb"))
93
 
94
- ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
95
- ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
96
- gen_text = args.gen_text if args.gen_text else config["gen_text"]
97
- gen_file = args.gen_file if args.gen_file else config["gen_file"]
98
 
99
  # patches for pip pkg user
100
  if "infer/examples/" in ref_audio:
@@ -107,34 +212,39 @@ if "voices" in config:
107
  if "infer/examples/" in voice_ref_audio:
108
  config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
109
 
 
 
 
110
  if gen_file:
111
  gen_text = codecs.open(gen_file, "r", "utf-8").read()
112
- output_dir = args.output_dir if args.output_dir else config["output_dir"]
113
- output_file = args.output_file if args.output_file else config["output_file"]
114
- model = args.model if args.model else config["model"]
115
- ckpt_file = args.ckpt_file if args.ckpt_file else ""
116
- vocab_file = args.vocab_file if args.vocab_file else ""
117
- remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
118
- speed = args.speed
119
 
120
  wave_path = Path(output_dir) / output_file
121
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
 
 
 
 
 
 
 
122
 
123
- vocoder_name = args.vocoder_name
124
- mel_spec_type = args.vocoder_name
125
  if vocoder_name == "vocos":
126
  vocoder_local_path = "../checkpoints/vocos-mel-24khz"
127
  elif vocoder_name == "bigvgan":
128
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
129
 
130
- vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
 
131
 
 
132
 
133
- # load models
134
  if model == "F5-TTS":
135
  model_cls = DiT
136
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
137
- if ckpt_file == "":
138
  if vocoder_name == "vocos":
139
  repo_name = "F5-TTS"
140
  exp_name = "F5TTS_Base"
@@ -148,22 +258,25 @@ if model == "F5-TTS":
148
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
149
 
150
  elif model == "E2-TTS":
151
- assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
 
152
  model_cls = UNetT
153
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
154
- if ckpt_file == "":
155
  repo_name = "E2-TTS"
156
  exp_name = "E2TTS_Base"
157
  ckpt_step = 1200000
158
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
159
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
160
 
161
-
162
  print(f"Using {model}...")
163
- ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
 
164
 
 
165
 
166
- def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
 
167
  main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
168
  if "voices" not in config:
169
  voices = {"main": main_voice}
@@ -171,16 +284,16 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
171
  voices = config["voices"]
172
  voices["main"] = main_voice
173
  for voice in voices:
 
 
174
  voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
175
  voices[voice]["ref_audio"], voices[voice]["ref_text"]
176
  )
177
- print("Voice:", voice)
178
- print("Ref_audio:", voices[voice]["ref_audio"])
179
- print("Ref_text:", voices[voice]["ref_text"])
180
 
181
  generated_audio_segments = []
182
  reg1 = r"(?=\[\w+\])"
183
- chunks = re.split(reg1, text_gen)
184
  reg2 = r"\[(\w+)\]"
185
  for text in chunks:
186
  if not text.strip():
@@ -195,14 +308,35 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
195
  print(f"Voice {voice} not found, using main.")
196
  voice = "main"
197
  text = re.sub(reg2, "", text)
198
- gen_text = text.strip()
199
- ref_audio = voices[voice]["ref_audio"]
200
- ref_text = voices[voice]["ref_text"]
201
  print(f"Voice: {voice}")
202
- audio, final_sample_rate, spectragram = infer_process(
203
- ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
 
 
 
 
 
 
 
 
 
 
 
 
204
  )
205
- generated_audio_segments.append(audio)
 
 
 
 
 
 
 
 
 
206
 
207
  if generated_audio_segments:
208
  final_wave = np.concatenate(generated_audio_segments)
@@ -218,9 +352,5 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
218
  print(f.name)
219
 
220
 
221
- def main():
222
- main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
223
-
224
-
225
  if __name__ == "__main__":
226
  main()
 
2
  import codecs
3
  import os
4
  import re
5
+ from datetime import datetime
6
  from importlib.resources import files
7
  from pathlib import Path
8
 
 
10
  import soundfile as sf
11
  import tomli
12
  from cached_path import cached_path
13
+ from omegaconf import OmegaConf
14
 
15
  from f5_tts.infer.utils_infer import (
16
+ mel_spec_type,
17
+ target_rms,
18
+ cross_fade_duration,
19
+ nfe_step,
20
+ cfg_strength,
21
+ sway_sampling_coef,
22
+ speed,
23
+ fix_duration,
24
  infer_process,
25
  load_model,
26
  load_vocoder,
 
29
  )
30
  from f5_tts.model import DiT, UNetT
31
 
32
+
33
  parser = argparse.ArgumentParser(
34
  prog="python3 infer-cli.py",
35
  description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
 
38
  parser.add_argument(
39
  "-c",
40
  "--config",
41
+ type=str,
42
  default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
43
+ help="The configuration file, default see infer/examples/basic/basic.toml",
44
  )
45
+
46
+
47
+ # Note: do not set default values here, so that defaults are read from the config file
48
+
49
  parser.add_argument(
50
  "-m",
51
  "--model",
52
+ type=str,
53
+ help="The model name: F5-TTS | E2-TTS",
54
+ )
55
+ parser.add_argument(
56
+ "-mc",
57
+ "--model_cfg",
58
+ type=str,
59
+ help="The path to F5-TTS model config file .yaml",
60
  )
61
  parser.add_argument(
62
  "-p",
63
  "--ckpt_file",
64
+ type=str,
65
+ help="The path to model checkpoint .pt, leave blank to use default",
66
  )
67
  parser.add_argument(
68
  "-v",
69
  "--vocab_file",
70
+ type=str,
71
+ help="The path to vocab file .txt, leave blank to use default",
72
+ )
73
+ parser.add_argument(
74
+ "-r",
75
+ "--ref_audio",
76
+ type=str,
77
+ help="The reference audio file.",
78
+ )
79
+ parser.add_argument(
80
+ "-s",
81
+ "--ref_text",
82
+ type=str,
83
+ help="The transcript/subtitle for the reference audio",
84
  )
 
 
85
  parser.add_argument(
86
  "-t",
87
  "--gen_text",
88
  type=str,
89
+ help="The text to make model synthesize a speech",
90
  )
91
  parser.add_argument(
92
  "-f",
93
  "--gen_file",
94
  type=str,
95
+ help="The file with text to generate, will ignore --gen_text",
96
  )
97
  parser.add_argument(
98
  "-o",
99
  "--output_dir",
100
  type=str,
101
+ help="The path to output folder",
102
  )
103
  parser.add_argument(
104
  "-w",
105
  "--output_file",
106
  type=str,
107
+ help="The name of output file",
108
+ )
109
+ parser.add_argument(
110
+ "--save_chunk",
111
+ action="store_true",
112
+ help="To save each audio chunks during inference",
113
  )
114
  parser.add_argument(
115
  "--remove_silence",
116
+ action="store_true",
117
+ help="To remove long silence found in ouput",
118
  )
 
119
  parser.add_argument(
120
  "--load_vocoder_from_local",
121
  action="store_true",
122
+ help="To load vocoder from local dir, default to ../checkpoints/vocos-mel-24khz",
123
+ )
124
+ parser.add_argument(
125
+ "--vocoder_name",
126
+ type=str,
127
+ choices=["vocos", "bigvgan"],
128
+ help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
129
+ )
130
+ parser.add_argument(
131
+ "--target_rms",
132
+ type=float,
133
+ help=f"Target output speech loudness normalization value, default {target_rms}",
134
+ )
135
+ parser.add_argument(
136
+ "--cross_fade_duration",
137
+ type=float,
138
+ help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
139
+ )
140
+ parser.add_argument(
141
+ "--nfe_step",
142
+ type=int,
143
+ help=f"The number of function evaluation (denoising steps), default {nfe_step}",
144
+ )
145
+ parser.add_argument(
146
+ "--cfg_strength",
147
+ type=float,
148
+ help=f"Classifier-free guidance strength, default {cfg_strength}",
149
+ )
150
+ parser.add_argument(
151
+ "--sway_sampling_coef",
152
+ type=float,
153
+ help=f"Sway Sampling coefficient, default {sway_sampling_coef}",
154
  )
155
  parser.add_argument(
156
  "--speed",
157
  type=float,
158
+ help=f"The speed of the generated audio, default {speed}",
159
+ )
160
+ parser.add_argument(
161
+ "--fix_duration",
162
+ type=float,
163
+ help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
164
  )
165
  args = parser.parse_args()
166
 
167
+
168
+ # config file
169
+
170
  config = tomli.load(open(args.config, "rb"))
171
 
172
+
173
+ # command-line interface parameters
174
+
175
+ model = args.model or config.get("model", "F5-TTS")
176
+ model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
177
+ ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
178
+ vocab_file = args.vocab_file or config.get("vocab_file", "")
179
+
180
+ ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
181
+ ref_text = args.ref_text or config.get("ref_text", "Some call me nature, others call me mother nature.")
182
+ gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
183
+ gen_file = args.gen_file or config.get("gen_file", "")
184
+
185
+ output_dir = args.output_dir or config.get("output_dir", "tests")
186
+ output_file = args.output_file or config.get(
187
+ "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
188
+ )
189
+
190
+ save_chunk = args.save_chunk or config.get("save_chunk", False)
191
+ remove_silence = args.remove_silence or config.get("remove_silence", False)
192
+ load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
193
+
194
+ vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
195
+ target_rms = args.target_rms or config.get("target_rms", target_rms)
196
+ cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
197
+ nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
198
+ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
199
+ sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
200
+ speed = args.speed or config.get("speed", speed)
201
+ fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
202
+
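
A hedged illustration of the resolution order implemented above (CLI flag first, then the TOML key, then the library default); `my_config.toml` and its `nfe_step = 64` entry are invented for this sketch:

```python
# Sketch of the CLI -> config -> default fallback used above.
# Assumes a hypothetical my_config.toml containing e.g.:  nfe_step = 64
import tomli

with open("my_config.toml", "rb") as f:
    config = tomli.load(f)

args_nfe_step = None                                     # e.g. --nfe_step was not passed on the CLI
nfe_step = args_nfe_step or config.get("nfe_step", 32)   # 32 stands in for the library default
print(nfe_step)                                          # -> 64, taken from the config file
```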
203
 
204
  # patches for pip pkg user
205
  if "infer/examples/" in ref_audio:
 
212
  if "infer/examples/" in voice_ref_audio:
213
  config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
214
 
215
+
216
+ # ignore gen_text if gen_file provided
217
+
218
  if gen_file:
219
  gen_text = codecs.open(gen_file, "r", "utf-8").read()
220
+
221
+
222
+ # output path
 
 
 
 
223
 
224
  wave_path = Path(output_dir) / output_file
225
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
226
+ if save_chunk:
227
+ output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
228
+ if not os.path.exists(output_chunk_dir):
229
+ os.makedirs(output_chunk_dir)
230
+
231
+
232
+ # load vocoder
233
 
 
 
234
  if vocoder_name == "vocos":
235
  vocoder_local_path = "../checkpoints/vocos-mel-24khz"
236
  elif vocoder_name == "bigvgan":
237
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
238
 
239
+ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
240
+
241
 
242
+ # load TTS model
243
 
 
244
  if model == "F5-TTS":
245
  model_cls = DiT
246
+ model_cfg = OmegaConf.load(model_cfg).model.arch
247
+ if not ckpt_file: # path not specified, download from repo
248
  if vocoder_name == "vocos":
249
  repo_name = "F5-TTS"
250
  exp_name = "F5TTS_Base"
 
258
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
259
 
260
  elif model == "E2-TTS":
261
+ assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
262
+ assert vocoder_name == "vocos", "E2-TTS only supports the vocos vocoder for now"
263
  model_cls = UNetT
264
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
265
+ if not ckpt_file: # path not specified, download from repo
266
  repo_name = "E2-TTS"
267
  exp_name = "E2TTS_Base"
268
  ckpt_step = 1200000
269
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
270
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
271
 
 
272
  print(f"Using {model}...")
273
+ ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
274
+
275
 
276
+ # inference process
277
 
278
+
279
+ def main():
280
  main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
281
  if "voices" not in config:
282
  voices = {"main": main_voice}
 
284
  voices = config["voices"]
285
  voices["main"] = main_voice
286
  for voice in voices:
287
+ print("Voice:", voice)
288
+ print("ref_audio ", voices[voice]["ref_audio"])
289
  voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
290
  voices[voice]["ref_audio"], voices[voice]["ref_text"]
291
  )
292
+ print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
 
 
293
 
294
  generated_audio_segments = []
295
  reg1 = r"(?=\[\w+\])"
296
+ chunks = re.split(reg1, gen_text)
297
  reg2 = r"\[(\w+)\]"
298
  for text in chunks:
299
  if not text.strip():
 
308
  print(f"Voice {voice} not found, using main.")
309
  voice = "main"
310
  text = re.sub(reg2, "", text)
311
+ ref_audio_ = voices[voice]["ref_audio"]
312
+ ref_text_ = voices[voice]["ref_text"]
313
+ gen_text_ = text.strip()
314
  print(f"Voice: {voice}")
315
+ audio_segment, final_sample_rate, spectragram = infer_process(
316
+ ref_audio_,
317
+ ref_text_,
318
+ gen_text_,
319
+ ema_model,
320
+ vocoder,
321
+ mel_spec_type=vocoder_name,
322
+ target_rms=target_rms,
323
+ cross_fade_duration=cross_fade_duration,
324
+ nfe_step=nfe_step,
325
+ cfg_strength=cfg_strength,
326
+ sway_sampling_coef=sway_sampling_coef,
327
+ speed=speed,
328
+ fix_duration=fix_duration,
329
  )
330
+ generated_audio_segments.append(audio_segment)
331
+
332
+ if save_chunk:
333
+ if len(gen_text_) > 200:
334
+ gen_text_ = gen_text_[:200] + " ... "
335
+ sf.write(
336
+ os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
337
+ audio_segment,
338
+ final_sample_rate,
339
+ )
340
 
341
  if generated_audio_segments:
342
  final_wave = np.concatenate(generated_audio_segments)
 
352
  print(f.name)
353
 
354
 
 
 
 
 
355
  if __name__ == "__main__":
356
  main()
src/f5_tts/infer/utils_infer.py CHANGED
@@ -138,7 +138,11 @@ asr_pipe = None
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
  dtype = (
141
- torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
 
 
 
 
142
  )
143
  global asr_pipe
144
  asr_pipe = pipeline(
@@ -171,7 +175,11 @@ def transcribe(ref_audio, language=None):
171
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
172
  if dtype is None:
173
  dtype = (
174
- torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
 
 
 
 
175
  )
176
  model = model.to(dtype)
177
 
@@ -338,7 +346,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
338
  else:
339
  ref_text += ". "
340
 
341
- print("ref_text ", ref_text)
342
 
343
  return ref_audio, ref_text
344
 
@@ -370,6 +378,7 @@ def infer_process(
370
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
371
  for i, gen_text in enumerate(gen_text_batches):
372
  print(f"gen_text {i}", gen_text)
 
373
 
374
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
375
  return infer_batch_process(
 
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
  dtype = (
141
+ torch.float16
142
+ if "cuda" in device
143
+ and torch.cuda.get_device_properties(device).major >= 6
144
+ and not torch.cuda.get_device_name().endswith("[ZLUDA]")
145
+ else torch.float32
146
  )
147
  global asr_pipe
148
  asr_pipe = pipeline(
 
175
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
176
  if dtype is None:
177
  dtype = (
178
+ torch.float16
179
+ if "cuda" in device
180
+ and torch.cuda.get_device_properties(device).major >= 6
181
+ and not torch.cuda.get_device_name().endswith("[ZLUDA]")
182
+ else torch.float32
183
  )
184
  model = model.to(dtype)
185
 
 
346
  else:
347
  ref_text += ". "
348
 
349
+ print("\nref_text ", ref_text)
350
 
351
  return ref_audio, ref_text
352
 
 
378
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
379
  for i, gen_text in enumerate(gen_text_batches):
380
  print(f"gen_text {i}", gen_text)
381
+ print("\n")
382
 
383
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
384
  return infer_batch_process(
src/f5_tts/model/backbones/dit.py CHANGED
@@ -105,6 +105,7 @@ class DiT(nn.Module):
105
  text_dim=None,
106
  conv_layers=0,
107
  long_skip_connection=False,
 
108
  ):
109
  super().__init__()
110
 
@@ -127,6 +128,16 @@ class DiT(nn.Module):
127
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
128
  self.proj_out = nn.Linear(dim, mel_dim)
129
 
 
 
 
 
 
 
 
 
 
 
130
  def forward(
131
  self,
132
  x: float["b n d"], # nosied input audio # noqa: F722
@@ -152,7 +163,10 @@ class DiT(nn.Module):
152
  residual = x
153
 
154
  for block in self.transformer_blocks:
155
- x = block(x, t, mask=mask, rope=rope)
 
 
 
156
 
157
  if self.long_skip_connection is not None:
158
  x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
 
105
  text_dim=None,
106
  conv_layers=0,
107
  long_skip_connection=False,
108
+ checkpoint_activations=False,
109
  ):
110
  super().__init__()
111
 
 
128
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
129
  self.proj_out = nn.Linear(dim, mel_dim)
130
 
131
+ self.checkpoint_activations = checkpoint_activations
132
+
133
+ def ckpt_wrapper(self, module):
134
+ # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
135
+ def ckpt_forward(*inputs):
136
+ outputs = module(*inputs)
137
+ return outputs
138
+
139
+ return ckpt_forward
140
+
141
  def forward(
142
  self,
143
  x: float["b n d"], # nosied input audio # noqa: F722
 
163
  residual = x
164
 
165
  for block in self.transformer_blocks:
166
+ if self.checkpoint_activations:
167
+ x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
168
+ else:
169
+ x = block(x, t, mask=mask, rope=rope)
170
 
171
  if self.long_skip_connection is not None:
172
  x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
src/f5_tts/model/trainer.py CHANGED
@@ -315,7 +315,7 @@ class Trainer:
315
  self.scheduler.step()
316
  self.optimizer.zero_grad()
317
 
318
- if self.is_main:
319
  self.ema_model.update()
320
 
321
  global_step += 1
 
315
  self.scheduler.step()
316
  self.optimizer.zero_grad()
317
 
318
+ if self.is_main and self.accelerator.sync_gradients:
319
  self.ema_model.update()
320
 
321
  global_step += 1
src/f5_tts/model/utils.py CHANGED
@@ -133,16 +133,23 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
133
 
134
  # convert char to pinyin
135
 
 
 
 
136
 
137
  def convert_char_to_pinyin(text_list, polyphone=True):
138
  final_text_list = []
139
- god_knows_why_en_testset_contains_zh_quote = str.maketrans(
140
- {"“": '"', "”": '"', "‘": "'", "’": "'"}
141
- ) # in case librispeech (orig no-pc) test-clean
142
- custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
 
 
 
 
 
143
  for text in text_list:
144
  char_list = []
145
- text = text.translate(god_knows_why_en_testset_contains_zh_quote)
146
  text = text.translate(custom_trans)
147
  for seg in jieba.cut(text):
148
  seg_byte_len = len(bytes(seg, "UTF-8"))
@@ -150,22 +157,21 @@ def convert_char_to_pinyin(text_list, polyphone=True):
150
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
151
  char_list.append(" ")
152
  char_list.extend(seg)
153
- elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
154
- seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
155
- for c in seg:
156
- if c not in "。,、;:?!《》【】—…":
157
  char_list.append(" ")
158
- char_list.append(c)
159
- else: # if mixed chinese characters, alphabets and symbols
160
  for c in seg:
161
  if ord(c) < 256:
162
  char_list.extend(c)
 
 
 
163
  else:
164
- if c not in "。,、;:?!《》【】—…":
165
- char_list.append(" ")
166
- char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
167
- else: # if is zh punc
168
- char_list.append(c)
169
  final_text_list.append(char_list)
170
 
171
  return final_text_list
 
133
 
134
  # convert char to pinyin
135
 
136
+ jieba.initialize()
137
+ print("Word segmentation module jieba initialized.\n")
138
+
139
 
140
  def convert_char_to_pinyin(text_list, polyphone=True):
141
  final_text_list = []
142
+ custom_trans = str.maketrans(
143
+ {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
144
+ ) # add custom trans here, to address oov
145
+
146
+ def is_chinese(c):
147
+ return (
148
+ "\u3100" <= c <= "\u9fff" # common chinese characters
149
+ )
150
+
151
  for text in text_list:
152
  char_list = []
 
153
  text = text.translate(custom_trans)
154
  for seg in jieba.cut(text):
155
  seg_byte_len = len(bytes(seg, "UTF-8"))
 
157
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
158
  char_list.append(" ")
159
  char_list.extend(seg)
160
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
161
+ seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
162
+ for i, c in enumerate(seg):
163
+ if is_chinese(c):
164
  char_list.append(" ")
165
+ char_list.append(seg_[i])
166
+ else: # if mixed characters, alphabets and symbols
167
  for c in seg:
168
  if ord(c) < 256:
169
  char_list.extend(c)
170
+ elif is_chinese(c):
171
+ char_list.append(" ")
172
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
173
  else:
174
+ char_list.append(c)
 
 
 
 
175
  final_text_list.append(char_list)
176
 
177
  return final_text_list
src/f5_tts/train/README.md CHANGED
@@ -2,9 +2,9 @@
2
 
3
  ## Prepare Dataset
4
 
5
- Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `src/f5_tts/model/dataset.py`.
6
 
7
- ### 1. Datasets used for pretrained models
8
  Download the corresponding dataset first, and fill in its path in the scripts.
9
 
10
  ```bash
@@ -16,6 +16,9 @@ python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
16
 
17
  # Prepare the LibriTTS dataset
18
  python src/f5_tts/train/datasets/prepare_libritts.py
 
 
 
19
  ```
20
 
21
  ### 2. Create custom dataset with metadata.csv
@@ -35,7 +38,12 @@ Once your datasets are prepared, you can start the training process.
35
  # setup accelerate config, e.g. use multi-gpu ddp, fp16
36
  # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
37
  accelerate config
38
- accelerate launch src/f5_tts/train/train.py
 
 
 
 
 
39
  ```
40
 
41
  ### 2. Finetuning practice
@@ -43,6 +51,8 @@ Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussio
43
 
44
  Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
45
 
 
 
46
  ### 3. Wandb Logging
47
 
48
  The `wandb/` dir will be created under the path where you run the training/finetuning scripts.
 
2
 
3
  ## Prepare Dataset
4
 
5
+ Example data processing scripts are provided below; you may also tailor your own, along with a Dataset class in `src/f5_tts/model/dataset.py`.
6
 
7
+ ### 1. Preparation scripts for some specific datasets
8
  Download the corresponding dataset first, and fill in its path in the scripts.
9
 
10
  ```bash
 
16
 
17
  # Prepare the LibriTTS dataset
18
  python src/f5_tts/train/datasets/prepare_libritts.py
19
+
20
+ # Prepare the LJSpeech dataset
21
+ python src/f5_tts/train/datasets/prepare_ljspeech.py
22
  ```
23
 
24
  ### 2. Create custom dataset with metadata.csv
 
38
  # setup accelerate config, e.g. use multi-gpu ddp, fp16
39
  # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
40
  accelerate config
41
+
42
+ # .yaml files are under src/f5_tts/configs directory
43
+ accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml
44
+
45
+ # possible to overwrite accelerate and hydra config
46
+ accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200
47
  ```
48
 
49
  ### 2. Finetuning practice
 
51
 
52
  Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
53
 
54
+ Setting `use_ema = True` can be harmful for early-stage finetuned checkpoints (trained for only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off and see if it gives better results.
55
+
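
For example, a minimal sketch of loading such an early-stage checkpoint without EMA weights; the checkpoint and vocab paths are hypothetical, and it assumes `load_model` in `utils_infer.py` forwards the same `use_ema` switch that `load_checkpoint` accepts:

```python
# Minimal sketch, not an official recipe: compare an early finetuned checkpoint
# loaded with and without EMA weights (paths below are placeholders).
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import load_model

model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_no_ema = load_model(
    DiT,
    model_cfg,
    "ckpts/my_dataset/model_last.pt",               # hypothetical finetuned checkpoint
    vocab_file="data/my_dataset_char/vocab.txt",    # hypothetical vocab path
    use_ema=False,                                  # assumed kwarg; default behaviour is use_ema=True
)
```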
56
  ### 3. Wandb Logging
57
 
58
  The `wandb/` dir will be created under the path where you run the training/finetuning scripts.
src/f5_tts/train/datasets/prepare_ljspeech.py ADDED
@@ -0,0 +1,65 @@
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.getcwd())
5
+
6
+ import json
7
+ from importlib.resources import files
8
+ from pathlib import Path
9
+ from tqdm import tqdm
10
+ import soundfile as sf
11
+ from datasets.arrow_writer import ArrowWriter
12
+
13
+
14
+ def main():
15
+ result = []
16
+ duration_list = []
17
+ text_vocab_set = set()
18
+
19
+ with open(meta_info, "r") as f:
20
+ lines = f.readlines()
21
+ for line in tqdm(lines):
22
+ uttr, text, norm_text = line.split("|")
23
+ norm_text = norm_text.strip()
24
+ wav_path = Path(dataset_dir) / "wavs" / f"{uttr}.wav"
25
+ duration = sf.info(wav_path).duration
26
+ if duration < 0.4 or duration > 30:
27
+ continue
28
+ result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration})
29
+ duration_list.append(duration)
30
+ text_vocab_set.update(list(norm_text))
31
+
32
+ # save preprocessed dataset to disk
33
+ if not os.path.exists(f"{save_dir}"):
34
+ os.makedirs(f"{save_dir}")
35
+ print(f"\nSaving to {save_dir} ...")
36
+
37
+ with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
38
+ for line in tqdm(result, desc="Writing to raw.arrow ..."):
39
+ writer.write(line)
40
+
41
+ # also dump durations to a separate json, for convenient use by DynamicBatchSampler
42
+ with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
43
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
44
+
45
+ # vocab map, i.e. tokenizer
46
+ # add alphabets and symbols (optional, if planning to finetune on de/fr etc.)
47
+ with open(f"{save_dir}/vocab.txt", "w") as f:
48
+ for vocab in sorted(text_vocab_set):
49
+ f.write(vocab + "\n")
50
+
51
+ print(f"\nFor {dataset_name}, sample count: {len(result)}")
52
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
53
+ print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
54
+
55
+
56
+ if __name__ == "__main__":
57
+ tokenizer = "char" # "pinyin" | "char"
58
+
59
+ dataset_dir = "<SOME_PATH>/LJSpeech-1.1"
60
+ dataset_name = f"LJSpeech_{tokenizer}"
61
+ meta_info = os.path.join(dataset_dir, "metadata.csv")
62
+ save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
63
+ print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
64
+
65
+ main()
src/f5_tts/train/train.py CHANGED
@@ -1,100 +1,72 @@
1
  # training script.
2
 
 
3
  from importlib.resources import files
4
 
 
 
5
  from f5_tts.model import CFM, DiT, Trainer, UNetT
6
  from f5_tts.model.dataset import load_dataset
7
  from f5_tts.model.utils import get_tokenizer
8
 
9
- # -------------------------- Dataset Settings --------------------------- #
10
-
11
- target_sample_rate = 24000
12
- n_mel_channels = 100
13
- hop_length = 256
14
- win_length = 1024
15
- n_fft = 1024
16
- mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
17
-
18
- tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
19
- tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
20
- dataset_name = "Emilia_ZH_EN"
21
-
22
- # -------------------------- Training Settings -------------------------- #
23
 
24
- exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
25
 
26
- learning_rate = 7.5e-5
 
 
 
 
27
 
28
- batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
29
- batch_size_type = "frame" # "frame" or "sample"
30
- max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
31
- grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
32
- max_grad_norm = 1.0
33
-
34
- epochs = 11 # use linear decay, thus epochs control the slope
35
- num_warmup_updates = 20000 # warmup steps
36
- save_per_updates = 50000 # save checkpoint per steps
37
- last_per_steps = 5000 # save last checkpoint per steps
38
-
39
- # model params
40
- if exp_name == "F5TTS_Base":
41
- wandb_resume_id = None
42
- model_cls = DiT
43
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
44
- elif exp_name == "E2TTS_Base":
45
- wandb_resume_id = None
46
- model_cls = UNetT
47
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
48
-
49
-
50
- # ----------------------------------------------------------------------- #
51
-
52
-
53
- def main():
54
- if tokenizer == "custom":
55
- tokenizer_path = tokenizer_path
56
  else:
57
- tokenizer_path = dataset_name
58
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
59
 
60
- mel_spec_kwargs = dict(
61
- n_fft=n_fft,
62
- hop_length=hop_length,
63
- win_length=win_length,
64
- n_mel_channels=n_mel_channels,
65
- target_sample_rate=target_sample_rate,
66
- mel_spec_type=mel_spec_type,
67
- )
68
 
69
  model = CFM(
70
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
71
- mel_spec_kwargs=mel_spec_kwargs,
72
  vocab_char_map=vocab_char_map,
73
  )
74
 
 
75
  trainer = Trainer(
76
  model,
77
- epochs,
78
- learning_rate,
79
- num_warmup_updates=num_warmup_updates,
80
- save_per_updates=save_per_updates,
81
- checkpoint_path=str(files("f5_tts").joinpath(f"../../ckpts/{exp_name}")),
82
- batch_size=batch_size_per_gpu,
83
- batch_size_type=batch_size_type,
84
- max_samples=max_samples,
85
- grad_accumulation_steps=grad_accumulation_steps,
86
- max_grad_norm=max_grad_norm,
 
87
  wandb_project="CFM-TTS",
88
  wandb_run_name=exp_name,
89
  wandb_resume_id=wandb_resume_id,
90
- last_per_steps=last_per_steps,
91
  log_samples=True,
 
92
  mel_spec_type=mel_spec_type,
 
 
93
  )
94
 
95
- train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
96
  trainer.train(
97
  train_dataset,
 
98
  resumable_with_seed=666, # seed for shuffling dataset
99
  )
100
 
 
1
  # training script.
2
 
3
+ import os
4
  from importlib.resources import files
5
 
6
+ import hydra
7
+
8
  from f5_tts.model import CFM, DiT, Trainer, UNetT
9
  from f5_tts.model.dataset import load_dataset
10
  from f5_tts.model.utils import get_tokenizer
11
 
12
+ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
14
 
15
+ @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
16
+ def main(cfg):
17
+ tokenizer = cfg.model.tokenizer
18
+ mel_spec_type = cfg.model.mel_spec.mel_spec_type
19
+ exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
20
 
21
+ # set text tokenizer
22
+ if tokenizer != "custom":
23
+ tokenizer_path = cfg.datasets.name
24
  else:
25
+ tokenizer_path = cfg.model.tokenizer_path
26
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
27
 
28
+ # set model
29
+ if "F5TTS" in cfg.model.name:
30
+ model_cls = DiT
31
+ elif "E2TTS" in cfg.model.name:
32
+ model_cls = UNetT
33
+ wandb_resume_id = None
 
 
34
 
35
  model = CFM(
36
+ transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
37
+ mel_spec_kwargs=cfg.model.mel_spec,
38
  vocab_char_map=vocab_char_map,
39
  )
40
 
41
+ # init trainer
42
  trainer = Trainer(
43
  model,
44
+ epochs=cfg.optim.epochs,
45
+ learning_rate=cfg.optim.learning_rate,
46
+ num_warmup_updates=cfg.optim.num_warmup_updates,
47
+ save_per_updates=cfg.ckpts.save_per_updates,
48
+ checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
49
+ batch_size=cfg.datasets.batch_size_per_gpu,
50
+ batch_size_type=cfg.datasets.batch_size_type,
51
+ max_samples=cfg.datasets.max_samples,
52
+ grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
53
+ max_grad_norm=cfg.optim.max_grad_norm,
54
+ logger=cfg.ckpts.logger,
55
  wandb_project="CFM-TTS",
56
  wandb_run_name=exp_name,
57
  wandb_resume_id=wandb_resume_id,
58
+ last_per_steps=cfg.ckpts.last_per_steps,
59
  log_samples=True,
60
+ bnb_optimizer=cfg.optim.bnb_optimizer,
61
  mel_spec_type=mel_spec_type,
62
+ is_local_vocoder=cfg.model.vocoder.is_local,
63
+ local_vocoder_path=cfg.model.vocoder.local_path,
64
  )
65
 
66
+ train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
67
  trainer.train(
68
  train_dataset,
69
+ num_workers=cfg.datasets.num_workers,
70
  resumable_with_seed=666, # seed for shuffling dataset
71
  )
72
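
For reference, a hedged sketch of the config structure this rewritten `train.py` expects: key names follow the `cfg.*` accesses above, values mirror the removed hard-coded defaults, and keys not present in the old script (`num_workers`, `bnb_optimizer`, `logger`, `save_dir`) are assumptions:

```python
# Sketch only: an in-code stand-in for src/f5_tts/configs/F5TTS_Base_train.yaml.
# Key names come from the cfg.* accesses in train.py above; values mirror the
# removed hard-coded settings; keys marked "assumed" are not confirmed by this diff.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model": {
            "name": "F5TTS_Base",
            "tokenizer": "pinyin",
            "tokenizer_path": None,
            "arch": {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4},
            "mel_spec": {
                "mel_spec_type": "vocos",
                "target_sample_rate": 24000,
                "n_mel_channels": 100,
                "hop_length": 256,
                "win_length": 1024,
                "n_fft": 1024,
            },
            "vocoder": {"is_local": False, "local_path": None},
        },
        "datasets": {
            "name": "Emilia_ZH_EN",
            "batch_size_per_gpu": 38400,
            "batch_size_type": "frame",
            "max_samples": 64,
            "num_workers": 16,  # assumed
        },
        "optim": {
            "epochs": 11,
            "learning_rate": 7.5e-5,
            "num_warmup_updates": 20000,
            "grad_accumulation_steps": 1,
            "max_grad_norm": 1.0,
            "bnb_optimizer": False,  # assumed
        },
        "ckpts": {
            "save_per_updates": 50000,
            "last_per_steps": 5000,
            "save_dir": "ckpts/F5TTS_Base",  # assumed, matches the old ckpts/{exp_name} path
            "logger": "wandb",  # assumed
        },
    }
)
print(OmegaConf.to_yaml(cfg))
```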