Merge branch 'mrfakename/E2-F5-TTS' into 'VX3/MimicYouFree'
Files changed:
- README_REPO.md +3 -3
- app.py +115 -78
- pyproject.toml +2 -1
- src/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- src/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- src/f5_tts/configs/F5TTS_Base_train.yaml +47 -0
- src/f5_tts/configs/F5TTS_Small_train.yaml +47 -0
- src/f5_tts/eval/README.md +9 -6
- src/f5_tts/eval/eval_infer_batch.py +2 -2
- src/f5_tts/eval/eval_librispeech_test_clean.py +77 -54
- src/f5_tts/eval/eval_seedtts_testset.py +76 -56
- src/f5_tts/eval/eval_utmos.py +44 -0
- src/f5_tts/eval/utils_eval.py +16 -8
- src/f5_tts/infer/README.md +5 -0
- src/f5_tts/infer/SHARED.md +93 -21
- src/f5_tts/infer/examples/basic/basic.toml +1 -1
- src/f5_tts/infer/examples/multi/story.toml +1 -0
- src/f5_tts/infer/infer_cli.py +181 -51
- src/f5_tts/infer/utils_infer.py +12 -3
- src/f5_tts/model/backbones/dit.py +15 -1
- src/f5_tts/model/trainer.py +1 -1
- src/f5_tts/model/utils.py +22 -16
- src/f5_tts/train/README.md +13 -3
- src/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- src/f5_tts/train/train.py +39 -67
README_REPO.md CHANGED
@@ -147,11 +147,11 @@ Note: Some model components have linting exceptions for E722 to accommodate tensor operations
 ## Acknowledgements

 - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
-- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
+- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
 - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
 - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
-- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
+- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder
-- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
+- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
 - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
 - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
 - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
app.py CHANGED
@@ -1,6 +1,7 @@
 # ruff: noqa: E402
 # Above allows ruff to ignore E402: module level import not at top of file

+import json
 import re
 import tempfile
 from collections import OrderedDict
@@ -43,6 +44,12 @@ from f5_tts.infer.utils_infer import (
 DEFAULT_TTS_MODEL = "F5-TTS"
 tts_model_choice = DEFAULT_TTS_MODEL

+DEFAULT_TTS_MODEL_CFG = [
+    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
+    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
+    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
+]
+

 # load models
@@ -103,8 +110,24 @@ def generate_response(messages, model, tokenizer):
 @gpu_decorator
 def infer(
+    ref_audio_orig,
+    ref_text,
+    gen_text,
+    model,
+    remove_silence,
+    cross_fade_duration=0.15,
+    nfe_step=32,
+    speed=1,
+    show_info=gr.Info,
 ):
+    if not ref_audio_orig:
+        gr.Warning("Please provide reference audio.")
+        return gr.update(), gr.update(), ref_text
+
+    if not gen_text.strip():
+        gr.Warning("Please enter text to generate.")
+        return gr.update(), gr.update(), ref_text
+
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

     if model == "F5-TTS":
@@ -120,7 +143,7 @@ def infer(
         global custom_ema_model, pre_custom_path
         if pre_custom_path != model[1]:
             show_info("Loading Custom TTS model...")
-            custom_ema_model = load_custom(model[1], vocab_path=model[2])
+            custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3])
             pre_custom_path = model[1]
         ema_model = custom_ema_model

@@ -131,6 +154,7 @@ def infer(
         ema_model,
         vocoder,
         cross_fade_duration=cross_fade_duration,
+        nfe_step=nfe_step,
         speed=speed,
         show_info=show_info,
         progress=gr.Progress(),
@@ -184,6 +208,14 @@ with gr.Blocks() as app_tts:
             step=0.1,
             info="Adjust the speed of the audio.",
         )
+        nfe_slider = gr.Slider(
+            label="NFE Steps",
+            minimum=4,
+            maximum=64,
+            value=32,
+            step=2,
+            info="Set the number of denoising steps.",
+        )
         cross_fade_duration_slider = gr.Slider(
             label="Cross-Fade Duration (s)",
             minimum=0.0,
@@ -203,6 +235,7 @@ with gr.Blocks() as app_tts:
         gen_text_input,
         remove_silence,
         cross_fade_duration_slider,
+        nfe_slider,
         speed_slider,
     ):
         audio_out, spectrogram_path, ref_text_out = infer(
@@ -211,10 +244,11 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             tts_model_choice,
             remove_silence,
+            cross_fade_duration=cross_fade_duration_slider,
+            nfe_step=nfe_slider,
+            speed=speed_slider,
         )
+        return audio_out, spectrogram_path, ref_text_out

     generate_btn.click(
         basic_tts,
@@ -224,6 +258,7 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             remove_silence,
             cross_fade_duration_slider,
+            nfe_slider,
             speed_slider,
         ],
         outputs=[audio_output, spectrogram_output, ref_text_input],
@@ -293,7 +328,7 @@ with gr.Blocks() as app_multistyle:
     )

     # Regular speech type (mandatory)
-    with gr.Row():
+    with gr.Row() as regular_row:
         with gr.Column():
             regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
             regular_insert = gr.Button("Insert Label", variant="secondary")
@@ -302,12 +337,12 @@ with gr.Blocks() as app_multistyle:

     # Regular speech type (max 100)
     max_speech_types = 100
-    speech_type_rows = []
-    speech_type_names = [regular_name]
-    speech_type_audios = [regular_audio]
-    speech_type_ref_texts = [regular_ref_text]
-    speech_type_delete_btns = []
-    speech_type_insert_btns = [regular_insert]
+    speech_type_rows = [regular_row]
+    speech_type_names = [regular_name]
+    speech_type_audios = [regular_audio]
+    speech_type_ref_texts = [regular_ref_text]
+    speech_type_delete_btns = [None]
+    speech_type_insert_btns = [regular_insert]

     # Additional speech types (99 more)
     for i in range(max_speech_types - 1):
@@ -328,51 +363,32 @@ with gr.Blocks() as app_multistyle:
     # Button to add speech type
     add_speech_type_btn = gr.Button("Add Speech Type")

+    # Keep track of autoincrement of speech types, no roll back
+    speech_type_count = 1

     # Function to add a speech type
+    def add_speech_type_fn():
+        row_updates = [gr.update() for _ in range(max_speech_types)]
+        global speech_type_count
         if speech_type_count < max_speech_types:
+            row_updates[speech_type_count] = gr.update(visible=True)
             speech_type_count += 1
         else:
+            gr.Warning("Exhausted maximum number of speech types. Consider restart the app.")
+        return row_updates

+    add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)

     # Function to delete a speech type
+    def delete_speech_type_fn():
+        return gr.update(visible=False), None, None, None

     # Update delete button clicks
+    for i in range(1, len(speech_type_delete_btns)):
+        speech_type_delete_btns[i].click(
+            delete_speech_type_fn,
+            outputs=[speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i]],
+        )

     # Text input for the prompt
     gen_text_input_multistyle = gr.Textbox(
@@ -386,7 +402,7 @@ with gr.Blocks() as app_multistyle:
         current_text = current_text or ""
         speech_type_name = speech_type_name or "None"
         updated_text = current_text + f"{{{speech_type_name}}} "
+        return updated_text

     return insert_speech_type_fn
@@ -446,10 +462,14 @@ with gr.Blocks() as app_multistyle:
             if style in speech_types:
                 current_style = style
             else:
+                gr.Warning(f"Type {style} is not available, will use Regular as default.")
                 current_style = "Regular"

+            try:
+                ref_audio = speech_types[current_style]["audio"]
+            except KeyError:
+                gr.Warning(f"Please provide reference audio for type {current_style}.")
+                return [None] + [speech_types[style]["ref_text"] for style in speech_types]
             ref_text = speech_types[current_style].get("ref_text", "")

             # Generate speech for this segment
@@ -464,12 +484,10 @@ with gr.Blocks() as app_multistyle:
         # Concatenate all audio segments
         if generated_audio_segments:
             final_audio_data = np.concatenate(generated_audio_segments)
-            return [(sr, final_audio_data)] + [
-                gr.update(value=speech_types[style]["ref_text"]) for style in speech_types
-            ]
+            return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
         else:
             gr.Warning("No audio generated.")
+            return [None] + [speech_types[style]["ref_text"] for style in speech_types]

     generate_multistyle_btn.click(
         generate_multistyle_speech,
@@ -487,7 +505,7 @@ with gr.Blocks() as app_multistyle:

     # Validation function to disable Generate button if speech types are missing
     def validate_speech_types(gen_text, regular_name, *args):
+        speech_type_names_list = args

         # Collect the speech types names
         speech_types_available = set()
@@ -651,7 +669,7 @@ Have a conversation with an AI using your reference voice!
             speed=1.0,
             show_info=print,  # show_info=print no pull to top when generating
         )
+        return audio_result, ref_text_out

     def clear_conversation():
         """Reset the conversation"""
@@ -744,34 +762,38 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
     """
     )

+    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")

     def load_last_used_custom():
         try:
+            custom = []
+            with open(last_used_custom, "r", encoding="utf-8") as f:
+                for line in f:
+                    custom.append(line.strip())
+            return custom
         except FileNotFoundError:
             last_used_custom.parent.mkdir(parents=True, exist_ok=True)
+            return DEFAULT_TTS_MODEL_CFG

     def switch_tts_model(new_choice):
         global tts_model_choice
         if new_choice == "Custom":  # override in case webpage is refreshed
+            custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
+            tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
+            return (
+                gr.update(visible=True, value=custom_ckpt_path),
+                gr.update(visible=True, value=custom_vocab_path),
+                gr.update(visible=True, value=custom_model_cfg),
+            )
         else:
             tts_model_choice = new_choice
+            return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)

+    def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
         global tts_model_choice
+        tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
+        with open(last_used_custom, "w", encoding="utf-8") as f:
+            f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")

     with gr.Row():
         if not USING_SPACES:
@@ -783,34 +805,49 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
         choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
     )
     custom_ckpt_path = gr.Dropdown(
+        choices=[DEFAULT_TTS_MODEL_CFG[0]],
         value=load_last_used_custom()[0],
         allow_custom_value=True,
+        label="Model: local_path | hf://user_id/repo_id/model_ckpt",
         visible=False,
     )
     custom_vocab_path = gr.Dropdown(
+        choices=[DEFAULT_TTS_MODEL_CFG[1]],
         value=load_last_used_custom()[1],
         allow_custom_value=True,
+        label="Vocab: local_path | hf://user_id/repo_id/vocab_file",
+        visible=False,
+    )
+    custom_model_cfg = gr.Dropdown(
+        choices=[
+            DEFAULT_TTS_MODEL_CFG[2],
+            json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
+        ],
+        value=load_last_used_custom()[2],
+        allow_custom_value=True,
+        label="Config: in a dictionary form",
         visible=False,
     )

     choose_tts_model.change(
         switch_tts_model,
         inputs=[choose_tts_model],
+        outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
         show_progress="hidden",
     )
     custom_ckpt_path.change(
         set_custom_model,
+        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
         show_progress="hidden",
     )
     custom_vocab_path.change(
         set_custom_model,
+        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
+        show_progress="hidden",
+    )
+    custom_model_cfg.change(
+        set_custom_model,
+        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
         show_progress="hidden",
     )
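The custom-model plumbing added above stores the model architecture as a JSON string (the third entry of `DEFAULT_TTS_MODEL_CFG`), persists all three entries to `infer/.cache/last_used_custom_model_info.txt`, and parses the config back with `json.loads` before handing it to `load_custom`. Below is a minimal, Gradio-free sketch of that round-trip; the helper names are hypothetical, and only the file layout and the config dict mirror the diff.

```python
import json
from pathlib import Path

# Mirrors DEFAULT_TTS_MODEL_CFG from the diff: ckpt path, vocab path, arch config as a JSON string.
DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]

# Same cache layout the app writes to (relative to the f5_tts package in the real code).
cache_file = Path("infer/.cache/last_used_custom_model_info.txt")


def save_custom(ckpt_path: str, vocab_path: str, model_cfg: str) -> None:
    # One value per line, exactly as set_custom_model() writes them.
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    cache_file.write_text(f"{ckpt_path}\n{vocab_path}\n{model_cfg}\n", encoding="utf-8")


def load_custom_info() -> list[str]:
    # load_last_used_custom() falls back to the default triplet when the cache is missing.
    try:
        return [line.strip() for line in cache_file.read_text(encoding="utf-8").splitlines()]
    except FileNotFoundError:
        return DEFAULT_TTS_MODEL_CFG


save_custom(*DEFAULT_TTS_MODEL_CFG)
ckpt, vocab, cfg_json = load_custom_info()
model_cfg = json.loads(cfg_json)  # this dict is what infer() forwards as model_cfg=model[3]
print(model_cfg["dim"], model_cfg["depth"])  # 1024 22
```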
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "f5-tts"
+version = "0.3.1"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
@@ -21,6 +21,7 @@ dependencies = [
     "datasets",
     "ema_pytorch>=0.5.2",
     "gradio>=3.45.2",
+    "hydra-core>=1.3.0",
     "jieba",
     "librosa",
     "matplotlib",
src/f5_tts/configs/E2TTS_Base_train.yaml ADDED
@@ -0,0 +1,44 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN  # dataset name
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0  # gradient clipping
+  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: E2TTS_Base
+  tokenizer: pinyin
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024
+    depth: 24
+    heads: 16
+    ff_mult: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
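With `hydra-core` now in the dependencies, configs like the one above are intended to be resolved by Hydra at training time (the `hydra.run.dir` key controls the working directory). A rough sketch of a Hydra entry point consuming this file follows; the `config_path`/`config_name` values and field accesses are assumptions drawn from this YAML, not the actual `src/f5_tts/train/train.py`.

```python
import hydra
from omegaconf import DictConfig, OmegaConf


# Assumed entry point: config_path/config_name point at the YAML shown above.
@hydra.main(version_base="1.3", config_path="src/f5_tts/configs", config_name="E2TTS_Base_train")
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg, resolve=True))                      # full resolved config
    print(cfg.model.name, cfg.model.arch.dim, cfg.model.arch.depth)  # E2TTS_Base 1024 24
    print(cfg.optim.learning_rate)                                   # 7.5e-05
    print(cfg.ckpts.save_dir)  # ckpts/E2TTS_Base_vocos_pinyin_Emilia_ZH_EN


if __name__ == "__main__":
    main()
```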
src/f5_tts/configs/E2TTS_Small_train.yaml ADDED
@@ -0,0 +1,44 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0
+  bnb_optimizer: False
+
+model:
+  name: E2TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 20
+    heads: 12
+    ff_mult: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Base_train.yaml ADDED
@@ -0,0 +1,47 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN  # dataset name
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0  # gradient clipping
+  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Base  # model name
+  tokenizer: pinyin  # tokenizer type
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024
+    depth: 22
+    heads: 16
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Small_train.yaml ADDED
@@ -0,0 +1,47 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0  # gradient clipping
+  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 18
+    heads: 12
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/eval/README.md CHANGED
@@ -39,11 +39,14 @@

 ### Objective Evaluation

-Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
+Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
 ```bash
-# Evaluation for Seed-TTS test set
-python src/f5_tts/eval/eval_seedtts_testset.py
+# Evaluation [WER] for Seed-TTS test [ZH] set
+python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8

-# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
-python src/f5_tts/eval/eval_librispeech_test_clean.py
+# Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
+python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
+
+# Evaluation [UTMOS]. --ext: Audio extension
+python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
+```
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -34,8 +34,6 @@ win_length = 1024
 n_fft = 1024
 target_rms = 0.1

-
-tokenizer = "pinyin"
 rel_path = str(files("f5_tts").joinpath("../../"))

@@ -49,6 +47,7 @@ def main():
     parser.add_argument("-n", "--expname", required=True)
     parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
     parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
+    parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])

     parser.add_argument("-nfe", "--nfestep", default=32, type=int)
     parser.add_argument("-o", "--odemethod", default="euler")
@@ -64,6 +63,7 @@ def main():
     ckpt_step = args.ckptstep
     ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
     mel_spec_type = args.mel_spec_type
+    tokenizer = args.tokenizer

     nfe_step = args.nfestep
     ode_method = args.odemethod
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -1,7 +1,9 @@
 # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)

+import argparse
+import json
 import os
+import sys

 sys.path.append(os.getcwd())

@@ -9,7 +11,6 @@ import multiprocessing as mp
 from importlib.resources import files

 import numpy as np
-
 from f5_tts.eval.utils_eval import (
     get_librispeech_test,
     run_asr_wer,
@@ -19,55 +20,77 @@ from f5_tts.eval.utils_eval import (
 rel_path = str(files("f5_tts").joinpath("../../"))


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
+    parser.add_argument("-l", "--lang", type=str, default="en")
+    parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
+    parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
+    parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
+    parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    eval_task = args.eval_task
+    lang = args.lang
+    librispeech_test_clean_path = args.librispeech_test_clean_path  # test-clean path
+    gen_wav_dir = args.gen_wav_dir
+    metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
+
+    gpus = list(range(args.gpu_nums))
+    test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
+
+    ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
+    ## leading to a low similarity for the ground truth in some cases.
+    # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth
+
+    local = args.local
+    if local:  # use local custom checkpoint dir
+        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
+    else:
+        asr_ckpt_dir = ""  # auto download to cache dir
+    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
+
+    # --------------------------- WER ---------------------------
+
+    if eval_task == "wer":
+        wer_results = []
+        wers = []
+
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_asr_wer, args)
+            for r in results:
+                wer_results.extend(r)
+
+        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
+        with open(wer_result_path, "w") as f:
+            for line in wer_results:
+                wers.append(line["wer"])
+                json_line = json.dumps(line, ensure_ascii=False)
+                f.write(json_line + "\n")
+
+        wer = round(np.mean(wers) * 100, 3)
+        print(f"\nTotal {len(wers)} samples")
+        print(f"WER : {wer}%")
+        print(f"Results have been saved to {wer_result_path}")
+
+    # --------------------------- SIM ---------------------------
+
+    if eval_task == "sim":
+        sims = []
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_sim, args)
+            for r in results:
+                sims.extend(r)
+
+        sim = round(sum(sims) / len(sims), 3)
+        print(f"\nTotal {len(sims)} samples")
+        print(f"SIM : {sim}")
+
+
+if __name__ == "__main__":
+    main()
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -1,7 +1,9 @@
 # Evaluate with Seed-TTS testset

+import argparse
+import json
 import os
+import sys

 sys.path.append(os.getcwd())

@@ -9,7 +11,6 @@ import multiprocessing as mp
 from importlib.resources import files

 import numpy as np
-
 from f5_tts.eval.utils_eval import (
     get_seed_tts_test,
     run_asr_wer,
@@ -19,57 +20,76 @@ from f5_tts.eval.utils_eval import (
 rel_path = str(files("f5_tts").joinpath("../../"))


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
+    parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
+    parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
+    parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
+    parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    eval_task = args.eval_task
+    lang = args.lang
+    gen_wav_dir = args.gen_wav_dir
+    metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
+
+    # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
+    # zh 1.254 seems a result of 4 workers wer_seed_tts
+    gpus = list(range(args.gpu_nums))
+    test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
+
+    local = args.local
+    if local:  # use local custom checkpoint dir
+        if lang == "zh":
+            asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
+        elif lang == "en":
+            asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
+    else:
+        asr_ckpt_dir = ""  # auto download to cache dir
+    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
+
+    # --------------------------- WER ---------------------------
+
+    if eval_task == "wer":
+        wer_results = []
+        wers = []
+
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_asr_wer, args)
+            for r in results:
+                wer_results.extend(r)
+
+        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
+        with open(wer_result_path, "w") as f:
+            for line in wer_results:
+                wers.append(line["wer"])
+                json_line = json.dumps(line, ensure_ascii=False)
+                f.write(json_line + "\n")
+
+        wer = round(np.mean(wers) * 100, 3)
+        print(f"\nTotal {len(wers)} samples")
+        print(f"WER : {wer}%")
+        print(f"Results have been saved to {wer_result_path}")
+
+    # --------------------------- SIM ---------------------------
+
+    if eval_task == "sim":
+        sims = []
+        with mp.Pool(processes=len(gpus)) as pool:
+            args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
+            results = pool.map(run_sim, args)
+            for r in results:
+                sims.extend(r)
+
+        sim = round(sum(sims) / len(sims), 3)
+        print(f"\nTotal {len(sims)} samples")
+        print(f"SIM : {sim}")
+
+
+if __name__ == "__main__":
+    main()
src/f5_tts/eval/eval_utmos.py ADDED
@@ -0,0 +1,44 @@
+import argparse
+import json
+from pathlib import Path
+
+import librosa
+import torch
+from tqdm import tqdm
+
+
+def main():
+    parser = argparse.ArgumentParser(description="UTMOS Evaluation")
+    parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.")
+    parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
+    args = parser.parse_args()
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
+    predictor = predictor.to(device)
+
+    audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
+    utmos_results = {}
+    utmos_score = 0
+
+    for audio_path in tqdm(audio_paths, desc="Processing"):
+        wav_name = audio_path.stem
+        wav, sr = librosa.load(audio_path, sr=None, mono=True)
+        wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+        score = predictor(wav_tensor, sr)
+        utmos_results[str(wav_name)] = score.item()
+        utmos_score += score.item()
+
+    avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
+    print(f"UTMOS: {avg_score}")
+
+    utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
+    with open(utmos_result_path, "w", encoding="utf-8") as f:
+        json.dump(utmos_results, f, ensure_ascii=False, indent=4)
+
+    print(f"Results have been saved to {utmos_result_path}")
+
+
+if __name__ == "__main__":
+    main()
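`eval_utmos.py` scores a whole directory; for a quick sanity check on a single clip, the same SpeechMOS predictor can be called directly. The snippet below reuses the exact `torch.hub` entry point and call pattern from the script above; the audio file name is a placeholder.

```python
import librosa
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Same hub entry point as in eval_utmos.py above.
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device)

wav, sr = librosa.load("sample.wav", sr=None, mono=True)  # placeholder path
score = predictor(torch.from_numpy(wav).to(device).unsqueeze(0), sr)
print(f"UTMOS: {score.item():.3f}")  # predicted MOS, roughly on a 1-5 scale
```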
src/f5_tts/eval/utils_eval.py CHANGED
@@ -2,6 +2,7 @@ import math
 import os
 import random
 import string
+from pathlib import Path

 import torch
 import torch.nn.functional as F
@@ -320,7 +321,7 @@ def run_asr_wer(args):
     from zhon.hanzi import punctuation

     punctuation_all = punctuation + string.punctuation
+    wer_results = []

     from jiwer import compute_measures
@@ -335,8 +336,8 @@ def run_asr_wer(args):
         for segment in segments:
             hypo = hypo + " " + segment.text

+        raw_truth = truth
+        raw_hypo = hypo

         for x in punctuation_all:
             truth = truth.replace(x, "")
@@ -360,9 +361,16 @@ def run_asr_wer(args):
         # dele = measures["deletions"] / len(ref_list)
         # inse = measures["insertions"] / len(ref_list)

+        wer_results.append(
+            {
+                "wav": Path(gen_wav).stem,
+                "truth": raw_truth,
+                "hypo": raw_hypo,
+                "wer": wer,
+            }
+        )

+    return wer_results


 # SIM Evaluation
@@ -381,7 +389,7 @@ def run_sim(args):
     model = model.cuda(device)
     model.eval()

+    sims = []
     for wav1, wav2, truth in tqdm(test_set):
         wav1, sr1 = torchaudio.load(wav1)
         wav2, sr2 = torchaudio.load(wav2)
@@ -400,6 +408,6 @@ def run_sim(args):

         sim = F.cosine_similarity(emb1, emb2)[0].item()
         # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
+        sims.append(sim)

+    return sims
src/f5_tts/infer/README.md CHANGED
@@ -12,6 +12,8 @@ To avoid possible inference failures, make sure you have seen through the following
 - Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
 - Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses.
 - Preprocess numbers to Chinese letters if you want to have them read in Chinese, otherwise in English.
+- If the generation output is blank (pure silence), check for ffmpeg installation (various tutorials online, blogs, videos, etc.).
+- Try turning off use_ema if using an early-stage finetuned checkpoint (one trained for only a few updates).


 ## Gradio App
@@ -62,6 +64,9 @@ f5-tts_infer-cli \
 # Choose Vocoder
 f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
 f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
+
+# More instructions
+f5-tts_infer-cli --help
 ```

 And a `.toml` file would help with more flexible usage.
src/f5_tts/infer/SHARED.md CHANGED
@@ -16,59 +16,131 @@
 <!-- omit in toc -->
 ### Supported Languages
 - [Multilingual](#multilingual)
+  - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
 - [English](#english)
+- [Finnish](#finnish)
+  - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
 - [French](#french)
+  - [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
+- [Hindi](#hindi)
+  - [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
+- [Italian](#italian)
+  - [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79)
+- [Japanese](#japanese)
+  - [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica)
+- [Mandarin](#mandarin)
+- [Spanish](#spanish)
+  - [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar)


 ## Multilingual

+#### F5-TTS Base @ zh & en @ F5-TTS
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
 |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|

 ```bash
+Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
+Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```

 *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*


+## English


+## Finnish
+
+#### F5-TTS Base @ fi @ AsmoKoskinen
+|Model|🤗Hugging Face|Data|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|

 ```bash
+Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
+Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```


 ## French

+#### F5-TTS Base @ fr @ RASPIAUDIO
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|

 ```bash
+Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
+Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```

 - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
 - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
 - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
+
+
+## Hindi
+
+#### F5-TTS Small @ hi @ SPRINGLab
+|Model|🤗Hugging Face|Data (Hours)|Model License|
+|:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
+
+```bash
+Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
+Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
+Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+```
+
+- Authors: SPRING Lab, Indian Institute of Technology, Madras
+- Website: https://asr.iitm.ac.in/
+
+
+## Italian
+
+#### F5-TTS Base @ it @ alien79
+|Model|🤗Hugging Face|Data|Model License|
+|:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
+
+```bash
+Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
+Vocab: hf://alien79/F5-TTS-italian/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+```
+
+- Trained by [Mithril Man](https://github.com/MithrilMan)
+- Model details on [hf project home](https://huggingface.co/alien79/F5-TTS-italian)
+- Open to collaborations to further improve the model
+
|
123 |
+
|
124 |
+
#### F5-TTS Base @ ja @ Jmica
|
125 |
+
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
126 |
+
|:---:|:------------:|:-----------:|:-------------:|
|
127 |
+
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
|
128 |
+
|
129 |
+
```bash
|
130 |
+
Model: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
|
131 |
+
Vocab: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
|
132 |
+
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
133 |
+
```
|
134 |
+
|
135 |
+
|
136 |
+
## Mandarin
|
137 |
+
|
138 |
+
|
139 |
+
## Spanish
|
140 |
+
|
141 |
+
#### F5-TTS Base @ es @ jpgallegoar
|
142 |
+
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
143 |
+
|:---:|:------------:|:-----------:|:-------------:|
|
144 |
+
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
|
145 |
+
|
146 |
+
- @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model.
|
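For any of the checkpoints listed above, inference can also be driven from Python with the same helpers that `infer_cli.py` (further down in this commit) uses. A minimal sketch, assuming the files are fetched with `cached_path` and reusing the Italian entry's paths and `Config` values; the reference audio, reference text, and output filenames are hypothetical placeholders:

```python
import soundfile as sf
from cached_path import cached_path

from f5_tts.infer.utils_infer import infer_process, load_model, load_vocoder, preprocess_ref_audio_text
from f5_tts.model import DiT

# checkpoint + vocab from the model card entry above
ckpt_file = str(cached_path("hf://alien79/F5-TTS-italian/model_159600.safetensors"))
vocab_file = str(cached_path("hf://alien79/F5-TTS-italian/vocab.txt"))
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)  # the "Config" line above

vocoder = load_vocoder(vocoder_name="vocos", is_local=False, local_path="../checkpoints/vocos-mel-24khz")
ema_model = load_model(DiT, model_cfg, ckpt_file, mel_spec_type="vocos", vocab_file=vocab_file)

# placeholder reference clip and transcript
ref_audio, ref_text = preprocess_ref_audio_text("ref_it.wav", "Testo di riferimento.")
audio, sample_rate, _ = infer_process(ref_audio, ref_text, "Ciao, questo è un test.", ema_model, vocoder, mel_spec_type="vocos")
sf.write("out_it.wav", audio, sample_rate)
```

The same flow is available end to end through `infer_cli.py` via `-p`, `-v` and `--vocoder_name`.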
src/f5_tts/infer/examples/basic/basic.toml
CHANGED
@@ -8,4 +8,4 @@ gen_text = "I don't really care what you call me. I've been a silent spectator,
 gen_file = ""
 remove_silence = false
 output_dir = "tests"
+output_file = "infer_cli_basic.wav"
src/f5_tts/infer/examples/multi/story.toml
CHANGED
@@ -8,6 +8,7 @@ gen_text = ""
 gen_file = "infer/examples/multi/story.txt"
 remove_silence = true
 output_dir = "tests"
+output_file = "infer_cli_story.wav"
 
 [voices.town]
 ref_audio = "infer/examples/multi/town.flac"
src/f5_tts/infer/infer_cli.py
CHANGED
@@ -2,6 +2,7 @@ import argparse
 import codecs
 import os
 import re
+from datetime import datetime
 from importlib.resources import files
 from pathlib import Path
 
@@ -9,8 +10,17 @@ import numpy as np
 import soundfile as sf
 import tomli
 from cached_path import cached_path
+from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
+    mel_spec_type,
+    target_rms,
+    cross_fade_duration,
+    nfe_step,
+    cfg_strength,
+    sway_sampling_coef,
+    speed,
+    fix_duration,
     infer_process,
     load_model,
     load_vocoder,
@@ -19,6 +29,7 @@ from f5_tts.infer.utils_infer import (
 )
 from f5_tts.model import DiT, UNetT
 
+
 parser = argparse.ArgumentParser(
     prog="python3 infer-cli.py",
     description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
@@ -27,74 +38,168 @@ parser = argparse.ArgumentParser(
 parser.add_argument(
     "-c",
     "--config",
+    type=str,
     default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
+    help="The configuration file, default see infer/examples/basic/basic.toml",
 )
+
+
+# Note. Not to provide default value here in order to read default from config file
+
 parser.add_argument(
     "-m",
     "--model",
+    type=str,
+    help="The model name: F5-TTS | E2-TTS",
+)
+parser.add_argument(
+    "-mc",
+    "--model_cfg",
+    type=str,
+    help="The path to F5-TTS model config file .yaml",
 )
 parser.add_argument(
     "-p",
     "--ckpt_file",
+    type=str,
+    help="The path to model checkpoint .pt, leave blank to use default",
 )
 parser.add_argument(
     "-v",
     "--vocab_file",
+    type=str,
+    help="The path to vocab file .txt, leave blank to use default",
+)
-parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
-parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
+parser.add_argument(
+    "-r",
+    "--ref_audio",
+    type=str,
+    help="The reference audio file.",
+)
+parser.add_argument(
+    "-s",
+    "--ref_text",
+    type=str,
+    help="The transcript/subtitle for the reference audio",
 )
 parser.add_argument(
     "-t",
     "--gen_text",
     type=str,
+    help="The text to make model synthesize a speech",
 )
 parser.add_argument(
     "-f",
     "--gen_file",
     type=str,
+    help="The file with text to generate, will ignore --gen_text",
 )
 parser.add_argument(
     "-o",
     "--output_dir",
     type=str,
+    help="The path to output folder",
 )
 parser.add_argument(
     "-w",
     "--output_file",
     type=str,
+    help="The name of output file",
+)
+parser.add_argument(
+    "--save_chunk",
+    action="store_true",
+    help="To save each audio chunks during inference",
 )
 parser.add_argument(
     "--remove_silence",
+    action="store_true",
+    help="To remove long silence found in ouput",
 )
-parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
 parser.add_argument(
     "--load_vocoder_from_local",
     action="store_true",
+    help="To load vocoder from local dir, default to ../checkpoints/vocos-mel-24khz",
+)
+parser.add_argument(
+    "--vocoder_name",
+    type=str,
+    choices=["vocos", "bigvgan"],
+    help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
+)
+parser.add_argument(
+    "--target_rms",
+    type=float,
+    help=f"Target output speech loudness normalization value, default {target_rms}",
+)
+parser.add_argument(
+    "--cross_fade_duration",
+    type=float,
+    help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
+)
+parser.add_argument(
+    "--nfe_step",
+    type=int,
+    help=f"The number of function evaluation (denoising steps), default {nfe_step}",
+)
+parser.add_argument(
+    "--cfg_strength",
+    type=float,
+    help=f"Classifier-free guidance strength, default {cfg_strength}",
+)
+parser.add_argument(
+    "--sway_sampling_coef",
+    type=float,
+    help=f"Sway Sampling coefficient, default {sway_sampling_coef}",
 )
 parser.add_argument(
     "--speed",
     type=float,
+    help=f"The speed of the generated audio, default {speed}",
+)
+parser.add_argument(
+    "--fix_duration",
+    type=float,
+    help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
 )
 args = parser.parse_args()
 
+
+# config file
+
 config = tomli.load(open(args.config, "rb"))
 
+
+# command-line interface parameters
+
+model = args.model or config.get("model", "F5-TTS")
+model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
+ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
+vocab_file = args.vocab_file or config.get("vocab_file", "")
+
+ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
+ref_text = args.ref_text or config.get("ref_text", "Some call me nature, others call me mother nature.")
+gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
+gen_file = args.gen_file or config.get("gen_file", "")
+
+output_dir = args.output_dir or config.get("output_dir", "tests")
+output_file = args.output_file or config.get(
+    "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
+)
+
+save_chunk = args.save_chunk or config.get("save_chunk", False)
+remove_silence = args.remove_silence or config.get("remove_silence", False)
+load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
+
+vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
+target_rms = args.target_rms or config.get("target_rms", target_rms)
+cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
+nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
+cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
+sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
+speed = args.speed or config.get("speed", speed)
+fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
+
 
 # patches for pip pkg user
 if "infer/examples/" in ref_audio:
@@ -107,34 +212,39 @@ if "voices" in config:
         if "infer/examples/" in voice_ref_audio:
             config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
 
+
+# ignore gen_text if gen_file provided
+
 if gen_file:
     gen_text = codecs.open(gen_file, "r", "utf-8").read()
-ckpt_file = args.ckpt_file if args.ckpt_file else ""
-vocab_file = args.vocab_file if args.vocab_file else ""
-remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
-speed = args.speed
+
+
+# output path
 
 wave_path = Path(output_dir) / output_file
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
+if save_chunk:
+    output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
+    if not os.path.exists(output_chunk_dir):
+        os.makedirs(output_chunk_dir)
+
+
+# load vocoder
 
-vocoder_name = args.vocoder_name
-mel_spec_type = args.vocoder_name
 if vocoder_name == "vocos":
     vocoder_local_path = "../checkpoints/vocos-mel-24khz"
 elif vocoder_name == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
 
+vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
+
 
-# load models
+# load TTS model
 
 if model == "F5-TTS":
     model_cls = DiT
+    model_cfg = OmegaConf.load(model_cfg).model.arch
+    if not ckpt_file:  # path not specified, download from repo
        if vocoder_name == "vocos":
            repo_name = "F5-TTS"
            exp_name = "F5TTS_Base"
@@ -148,22 +258,25 @@ if model == "F5-TTS":
         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
 
 elif model == "E2-TTS":
+    assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
+    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
     model_cls = UNetT
     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+    if not ckpt_file:  # path not specified, download from repo
         repo_name = "E2-TTS"
         exp_name = "E2TTS_Base"
         ckpt_step = 1200000
         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
         # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
 
 print(f"Using {model}...")
+ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
+
 
+# inference process
 
+
+def main():
     main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -171,16 +284,16 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
         voices = config["voices"]
         voices["main"] = main_voice
     for voice in voices:
+        print("Voice:", voice)
+        print("ref_audio ", voices[voice]["ref_audio"])
         voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
             voices[voice]["ref_audio"], voices[voice]["ref_text"]
         )
-        print("Ref_audio:", voices[voice]["ref_audio"])
-        print("Ref_text:", voices[voice]["ref_text"])
+        print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
 
     generated_audio_segments = []
     reg1 = r"(?=\[\w+\])"
+    chunks = re.split(reg1, gen_text)
     reg2 = r"\[(\w+)\]"
     for text in chunks:
         if not text.strip():
@@ -195,14 +308,35 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
             print(f"Voice {voice} not found, using main.")
             voice = "main"
         text = re.sub(reg2, "", text)
+        ref_audio_ = voices[voice]["ref_audio"]
+        ref_text_ = voices[voice]["ref_text"]
+        gen_text_ = text.strip()
         print(f"Voice: {voice}")
+        audio_segment, final_sample_rate, spectragram = infer_process(
+            ref_audio_,
+            ref_text_,
+            gen_text_,
+            ema_model,
+            vocoder,
+            mel_spec_type=vocoder_name,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
+            nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
        )
+        generated_audio_segments.append(audio_segment)
+
+        if save_chunk:
+            if len(gen_text_) > 200:
+                gen_text_ = gen_text_[:200] + " ... "
+            sf.write(
+                os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
+                audio_segment,
+                final_sample_rate,
+            )
 
     if generated_audio_segments:
         final_wave = np.concatenate(generated_audio_segments)
@@ -218,9 +352,5 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
         print(f.name)
 
 
-def main():
-    main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
-
-
 if __name__ == "__main__":
     main()
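The rewritten CLI resolves every setting with the same precedence: an explicit command-line flag, then the key in the `-c` .toml file, then the default imported from `utils_infer`. A minimal sketch of that resolution order using `nfe_step` as the example; the numeric values here are illustrative placeholders, not the library defaults:

```python
import tomli

cli_value = None                         # what argparse yields when --nfe_step is not passed
config = tomli.loads("nfe_step = 16\n")  # stands in for the file parsed from -c/--config
fallback = 32                            # stands in for the default imported from utils_infer

nfe_step = cli_value or config.get("nfe_step", fallback)
print(nfe_step)  # 16: the config file wins because no CLI flag was given
```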
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -138,7 +138,11 @@ asr_pipe = None
 def initialize_asr_pipeline(device: str = device, dtype=None):
     if dtype is None:
         dtype = (
-            torch.float16
+            torch.float16
+            if "cuda" in device
+            and torch.cuda.get_device_properties(device).major >= 6
+            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
+            else torch.float32
         )
     global asr_pipe
     asr_pipe = pipeline(

@@ -171,7 +175,11 @@ def transcribe(ref_audio, language=None):
 def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
     if dtype is None:
         dtype = (
-            torch.float16
+            torch.float16
+            if "cuda" in device
+            and torch.cuda.get_device_properties(device).major >= 6
+            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
+            else torch.float32
         )
     model = model.to(dtype)

@@ -338,7 +346,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
     else:
         ref_text += ". "
 
+    print("\nref_text ", ref_text)
 
     return ref_audio, ref_text

@@ -370,6 +378,7 @@ def infer_process(
     gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
     for i, gen_text in enumerate(gen_text_batches):
         print(f"gen_text {i}", gen_text)
+    print("\n")
 
     show_info(f"Generating audio in {len(gen_text_batches)} batches...")
     return infer_batch_process(
src/f5_tts/model/backbones/dit.py
CHANGED
@@ -105,6 +105,7 @@ class DiT(nn.Module):
         text_dim=None,
         conv_layers=0,
         long_skip_connection=False,
+        checkpoint_activations=False,
     ):
         super().__init__()
 
@@ -127,6 +128,16 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
 
+        self.checkpoint_activations = checkpoint_activations
+
+    def ckpt_wrapper(self, module):
+        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
+        def ckpt_forward(*inputs):
+            outputs = module(*inputs)
+            return outputs
+
+        return ckpt_forward
+
     def forward(
         self,
         x: float["b n d"],  # nosied input audio  # noqa: F722
 
@@ -152,7 +163,10 @@ class DiT(nn.Module):
         residual = x
 
         for block in self.transformer_blocks:
+            if self.checkpoint_activations:
+                x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
+            else:
+                x = block(x, t, mask=mask, rope=rope)
 
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
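A minimal sketch (not part of the diff) of how the new flag would be switched on when constructing the backbone directly. The architecture numbers mirror the F5TTS_Base config used elsewhere in this commit; the vocab size is a hypothetical placeholder:

```python
from f5_tts.model import DiT

transformer = DiT(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4,
    text_num_embeds=2546,         # placeholder: use the size of your vocab.txt
    mel_dim=100,                  # n_mel_channels, matching the mel settings used in this repo
    checkpoint_activations=True,  # recompute each transformer block in backward, trading compute for memory
)
```

During training this would normally be driven through the model config rather than built by hand; the sketch only shows where the new argument ends up.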
src/f5_tts/model/trainer.py
CHANGED
@@ -315,7 +315,7 @@ class Trainer:
                 self.scheduler.step()
                 self.optimizer.zero_grad()
 
-                if self.is_main:
+                if self.is_main and self.accelerator.sync_gradients:
                     self.ema_model.update()
 
                 global_step += 1
src/f5_tts/model/utils.py
CHANGED
@@ -133,16 +133,23 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
 
 # convert char to pinyin
 
+jieba.initialize()
+print("Word segmentation module jieba initialized.\n")
+
 
 def convert_char_to_pinyin(text_list, polyphone=True):
     final_text_list = []
+    custom_trans = str.maketrans(
+        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
+    )  # add custom trans here, to address oov
+
+    def is_chinese(c):
+        return (
+            "\u3100" <= c <= "\u9fff"  # common chinese characters
+        )
+
     for text in text_list:
         char_list = []
-        text = text.translate(god_knows_why_en_testset_contains_zh_quote)
         text = text.translate(custom_trans)
         for seg in jieba.cut(text):
             seg_byte_len = len(bytes(seg, "UTF-8"))

@@ -150,22 +157,21 @@ def convert_char_to_pinyin(text_list, polyphone=True):
             if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                 char_list.append(" ")
             char_list.extend(seg)
+            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
+                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
+                for i, c in enumerate(seg):
+                    if is_chinese(c):
                         char_list.append(" ")
+                    char_list.append(seg_[i])
+            else:  # if mixed characters, alphabets and symbols
                 for c in seg:
                     if ord(c) < 256:
                         char_list.extend(c)
+                    elif is_chinese(c):
+                        char_list.append(" ")
+                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                     else:
+                        char_list.append(c)
         final_text_list.append(char_list)
 
     return final_text_list
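A quick sketch (not part of the diff) of what the updated text front-end does with mixed input: jieba segments the text, Chinese characters are mapped to TONE3 pinyin with tone sandhi, and ASCII passes through unchanged; the sample sentence is arbitrary:

```python
from f5_tts.model.utils import convert_char_to_pinyin

# One batch entry mixing Chinese and ASCII; the semicolon is also remapped by custom_trans.
print(convert_char_to_pinyin(["这是 F5-TTS 的测试; done."], polyphone=True))
```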
src/f5_tts/train/README.md
CHANGED
@@ -2,9 +2,9 @@

## Prepare Dataset

Example data processing scripts are provided below, and you may tailor your own along with a Dataset class in `src/f5_tts/model/dataset.py`.

### 1. Some specific dataset preparation scripts
Download the corresponding dataset first, and fill in the path in the scripts.

@@ -16,6 +16,9 @@ python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py

```bash
# Prepare the LibriTTS dataset
python src/f5_tts/train/datasets/prepare_libritts.py

# Prepare the LJSpeech dataset
python src/f5_tts/train/datasets/prepare_ljspeech.py
```

### 2. Create custom dataset with metadata.csv

@@ -35,7 +38,12 @@ Once your datasets are prepared, you can start the training process.

```bash
# setup accelerate config, e.g. use multi-gpu ddp, fp16
# will be to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config

# .yaml files are under src/f5_tts/configs directory
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml

# possible to overwrite accelerate and hydra config
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200
```

### 2. Finetuning practice

@@ -43,6 +51,8 @@ Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussio

Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).

Setting `use_ema = True` can be harmful for early-stage finetuned checkpoints (which have seen only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off and see if it gives better results.

### 3. Wandb Logging

The `wandb/` dir will be created under the path where you run the training/finetuning scripts.
src/f5_tts/train/datasets/prepare_ljspeech.py
ADDED
@@ -0,0 +1,65 @@
import os
import sys

sys.path.append(os.getcwd())

import json
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
from datasets.arrow_writer import ArrowWriter


def main():
    result = []
    duration_list = []
    text_vocab_set = set()

    with open(meta_info, "r") as f:
        lines = f.readlines()
        for line in tqdm(lines):
            uttr, text, norm_text = line.split("|")
            norm_text = norm_text.strip()
            wav_path = Path(dataset_dir) / "wavs" / f"{uttr}.wav"
            duration = sf.info(wav_path).duration
            if duration < 0.4 or duration > 30:
                continue
            result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration})
            duration_list.append(duration)
            text_vocab_set.update(list(norm_text))

    # save preprocessed dataset to disk
    if not os.path.exists(f"{save_dir}"):
        os.makedirs(f"{save_dir}")
    print(f"\nSaving to {save_dir} ...")

    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    with open(f"{save_dir}/vocab.txt", "w") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")


if __name__ == "__main__":
    tokenizer = "char"  # "pinyin" | "char"

    dataset_dir = "<SOME_PATH>/LJSpeech-1.1"
    dataset_name = f"LJSpeech_{tokenizer}"
    meta_info = os.path.join(dataset_dir, "metadata.csv")
    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
    print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")

    main()
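For reference, the parser above expects the stock LJSpeech `metadata.csv` layout: one pipe-separated row of utterance id, raw transcript, and normalized transcript. A tiny illustrative sketch (the row text is shortened, not taken from this commit):

```python
# One metadata.csv row as consumed by main() above: "<utt id>|<raw text>|<normalized text>"
line = "LJ001-0001|Printing, in the only sense ...|Printing, in the only sense ...\n"
uttr, text, norm_text = line.split("|")
print(uttr, norm_text.strip())
```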
src/f5_tts/train/train.py
CHANGED
@@ -1,100 +1,72 @@
 # training script.
 
+import os
 from importlib.resources import files
 
+import hydra
+
 from f5_tts.model import CFM, DiT, Trainer, UNetT
 from f5_tts.model.dataset import load_dataset
 from f5_tts.model.utils import get_tokenizer
 
-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-win_length = 1024
-n_fft = 1024
-mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
-
-tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
-tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
-dataset_name = "Emilia_ZH_EN"
-
-# -------------------------- Training Settings -------------------------- #
-
-exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
-
-grad_accumulation_steps = 1  # note: updates = steps / grad_accumulation_steps
-max_grad_norm = 1.0
-
-epochs = 11  # use linear decay, thus epochs control the slope
-num_warmup_updates = 20000  # warmup steps
-save_per_updates = 50000  # save checkpoint per steps
-last_per_steps = 5000  # save last checkpoint per steps
-
-# model params
-if exp_name == "F5TTS_Base":
-    wandb_resume_id = None
-    model_cls = DiT
-    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-elif exp_name == "E2TTS_Base":
-    wandb_resume_id = None
-    model_cls = UNetT
-    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-
-# ----------------------------------------------------------------------- #
-
-def main():
-    if tokenizer == "custom":
-        tokenizer_path = tokenizer_path
+os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to root of project (local editable)
+
+
+@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
+def main(cfg):
+    tokenizer = cfg.model.tokenizer
+    mel_spec_type = cfg.model.mel_spec.mel_spec_type
+    exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
+
+    # set text tokenizer
+    if tokenizer != "custom":
+        tokenizer_path = cfg.datasets.name
     else:
+        tokenizer_path = cfg.model.tokenizer_path
     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
 
+    # set model
+    if "F5TTS" in cfg.model.name:
+        model_cls = DiT
+    elif "E2TTS" in cfg.model.name:
+        model_cls = UNetT
+    wandb_resume_id = None
 
     model = CFM(
+        transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
+        mel_spec_kwargs=cfg.model.mel_spec,
         vocab_char_map=vocab_char_map,
     )
 
+    # init trainer
     trainer = Trainer(
         model,
-        epochs,
-        learning_rate,
-        num_warmup_updates=num_warmup_updates,
-        save_per_updates=save_per_updates,
-        batch_size=batch_size_per_gpu,
-        batch_size_type=batch_size_type,
-        max_samples=max_samples,
-        grad_accumulation_steps=grad_accumulation_steps,
-        max_grad_norm=max_grad_norm,
+        epochs=cfg.optim.epochs,
+        learning_rate=cfg.optim.learning_rate,
+        num_warmup_updates=cfg.optim.num_warmup_updates,
+        save_per_updates=cfg.ckpts.save_per_updates,
+        checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
+        batch_size=cfg.datasets.batch_size_per_gpu,
+        batch_size_type=cfg.datasets.batch_size_type,
+        max_samples=cfg.datasets.max_samples,
+        grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
+        max_grad_norm=cfg.optim.max_grad_norm,
+        logger=cfg.ckpts.logger,
         wandb_project="CFM-TTS",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
-        last_per_steps=last_per_steps,
+        last_per_steps=cfg.ckpts.last_per_steps,
         log_samples=True,
+        bnb_optimizer=cfg.optim.bnb_optimizer,
         mel_spec_type=mel_spec_type,
+        is_local_vocoder=cfg.model.vocoder.is_local,
+        local_vocoder_path=cfg.model.vocoder.local_path,
     )
 
+    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
    trainer.train(
         train_dataset,
+        num_workers=cfg.datasets.num_workers,
         resumable_with_seed=666,  # seed for shuffling dataset
     )
 