Spaces:
Sleeping
Sleeping
mrfakename
committed on
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- src/f5_tts/train/finetune_gradio.py +132 -14
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -23,7 +23,7 @@ from datasets.arrow_writer import ArrowWriter
|
|
23 |
from safetensors.torch import save_file
|
24 |
from scipy.io import wavfile
|
25 |
from transformers import pipeline
|
26 |
-
|
27 |
from f5_tts.api import F5TTS
|
28 |
from f5_tts.model.utils import convert_char_to_pinyin
|
29 |
|
@@ -731,6 +731,97 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, s
|
|
731 |
return f"An error occurred: {e}"
|
732 |
|
733 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
734 |
def vocab_check(project_name):
|
735 |
name_project = project_name
|
736 |
path_project = os.path.join(path_data, name_project)
|
@@ -739,7 +830,7 @@ def vocab_check(project_name):
|
|
739 |
|
740 |
file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
|
741 |
if not os.path.isfile(file_vocab):
|
742 |
-
return f"the file {file_vocab} not found !"
|
743 |
|
744 |
with open(file_vocab, "r", encoding="utf-8-sig") as f:
|
745 |
data = f.read()
|
@@ -747,7 +838,7 @@ def vocab_check(project_name):
|
|
747 |
vocab = set(vocab)
|
748 |
|
749 |
if not os.path.isfile(file_metadata):
|
750 |
-
return f"the file {file_metadata} not found !"
|
751 |
|
752 |
with open(file_metadata, "r", encoding="utf-8-sig") as f:
|
753 |
data = f.read()
|
@@ -765,12 +856,15 @@ def vocab_check(project_name):
|
|
765 |
if t not in vocab and t not in miss_symbols_keep:
|
766 |
miss_symbols.append(t)
|
767 |
miss_symbols_keep[t] = t
|
|
|
768 |
if miss_symbols == []:
|
|
|
769 |
info = "You can train using your language !"
|
770 |
else:
|
771 |
-
|
|
|
772 |
|
773 |
-
return info
|
774 |
|
775 |
|
776 |
def get_random_sample_prepare(project_name):
|
@@ -1009,6 +1103,38 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
|
|
1009 |
outputs=[random_text_transcribe, random_audio_transcribe],
|
1010 |
)
|
1011 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1012 |
with gr.TabItem("prepare Data"):
|
1013 |
gr.Markdown(
|
1014 |
"""```plaintext
|
@@ -1030,7 +1156,7 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
|
|
1030 |
|
1031 |
```"""
|
1032 |
)
|
1033 |
-
ch_tokenizern = gr.Checkbox(label="create vocabulary
|
1034 |
bt_prepare = bt_create = gr.Button("prepare")
|
1035 |
txt_info_prepare = gr.Text(label="info", value="")
|
1036 |
txt_vocab_prepare = gr.Text(label="vocab", value="")
|
@@ -1048,14 +1174,6 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
|
|
1048 |
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
|
1049 |
)
|
1050 |
|
1051 |
-
with gr.TabItem("vocab check"):
|
1052 |
-
gr.Markdown("""```plaintext
|
1053 |
-
check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. for finetune new language
|
1054 |
-
```""")
|
1055 |
-
check_button = gr.Button("check vocab")
|
1056 |
-
txt_info_check = gr.Text(label="info", value="")
|
1057 |
-
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
|
1058 |
-
|
1059 |
with gr.TabItem("train Data"):
|
1060 |
gr.Markdown("""```plaintext
|
1061 |
The auto-setting is still experimental. Please make sure that the epochs , save per updates , and last per steps are set correctly, or change them manually as needed.
|
|
|
23 |
from safetensors.torch import save_file
|
24 |
from scipy.io import wavfile
|
25 |
from transformers import pipeline
|
26 |
+
from cached_path import cached_path
|
27 |
from f5_tts.api import F5TTS
|
28 |
from f5_tts.model.utils import convert_char_to_pinyin
|
29 |
|
|
|
731 |
return f"An error occurred: {e}"
|
732 |
|
733 |
|
734 |
+
def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
    """Grow the EMA text-embedding table of a checkpoint and save a copy.

    The checkpoint at ``ckpt_path`` is loaded on CPU, the tensor stored under
    ``ckpt["ema_model_state_dict"]["ema_model.transformer.text_embed.text_embed.weight"]``
    is extended by ``num_new_tokens`` rows, and the whole checkpoint is written
    to ``new_ckpt_path``. Existing rows are copied unchanged; the added rows are
    drawn from ``torch.randn``.

    Side effects: seeds Python's ``random`` and torch's RNGs with a fixed seed
    (666), sets ``PYTHONHASHSEED`` in ``os.environ`` and switches cuDNN to
    deterministic / non-benchmark mode. These changes are process-wide and are
    not restored afterwards.

    Args:
        ckpt_path: Path of the checkpoint to read.
        new_ckpt_path: Path the expanded checkpoint is saved to.
        num_new_tokens: Number of embedding rows to append (must be >= 0).

    Returns:
        The number of rows in the expanded embedding table.

    Raises:
        ValueError: If ``num_new_tokens`` is negative.
        KeyError: If the checkpoint has no EMA state dict or the state dict
            does not contain the text-embedding weight.
    """
    if num_new_tokens < 0:
        raise ValueError(f"num_new_tokens must be >= 0, got {num_new_tokens}")

    # Fixed seed so the randomly initialised rows are reproducible between runs.
    seed = 666
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    ckpt = torch.load(ckpt_path, map_location="cpu")

    embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
    ema_sd = ckpt.get("ema_model_state_dict")
    # Fail with a message that names what is missing instead of an opaque
    # KeyError raised from indexing a substituted empty dict.
    if ema_sd is None:
        raise KeyError(f"'ema_model_state_dict' not found in checkpoint {ckpt_path}")
    if embed_key_ema not in ema_sd:
        raise KeyError(f"'{embed_key_ema}' not found in 'ema_model_state_dict' of {ckpt_path}")

    old_embed_ema = ema_sd[embed_key_ema]
    vocab_old = old_embed_ema.size(0)
    embed_dim = old_embed_ema.size(1)
    vocab_new = vocab_old + num_new_tokens

    # Old rows keep their values; only the appended rows are random.
    new_embeddings = torch.zeros((vocab_new, embed_dim))
    new_embeddings[:vocab_old] = old_embed_ema
    new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim))

    # ema_sd is the dict held by ckpt, so this update is part of what gets saved.
    ema_sd[embed_key_ema] = new_embeddings

    torch.save(ckpt, new_ckpt_path)

    return vocab_new
|
765 |
+
|
766 |
+
|
767 |
+
def vocab_count(text):
    """Return, as a string, how many symbols a comma-separated list holds.

    Spaces are removed from every entry and entries that end up empty are not
    counted, so an empty textbox yields "0" and a trailing comma does not add
    a phantom symbol. This mirrors how ``vocab_extend`` strips spaces from each
    entry before using it.

    Args:
        text: Comma-separated symbols as typed by the user (may be empty).

    Returns:
        The number of non-empty symbols, formatted as a string.
    """
    entries = (item.replace(" ", "") for item in (text or "").split(","))
    return str(sum(1 for item in entries if item))
|
769 |
+
|
770 |
+
|
771 |
+
def vocab_extend(project_name, symbols, model_type):
    """Add new symbols to a project's vocab and expand the base checkpoint.

    Reads the base vocabulary ``Emilia_ZH_EN_pinyin/vocab.txt`` under
    ``path_data``, appends every requested symbol that is not already present,
    writes the result to ``<path_data>/<project_name>/vocab.txt`` and then calls
    ``expand_model_embeddings`` so the checkpoint saved as
    ``<path_project_ckpts>/<project_name>/model_1200000.pt`` gains one
    embedding row per added symbol.

    Args:
        project_name: Name of the project folder.
        symbols: Comma-separated symbols to add. Spaces inside an entry are
            removed; blank entries and repeated symbols are ignored.
        model_type: ``"F5-TTS"`` selects the F5-TTS base checkpoint; any other
            value selects the E2-TTS one. Both are resolved via ``cached_path``.

    Returns:
        A human-readable status message: either the reason nothing was done,
        or a summary of the old/new vocabulary sizes and the added symbols.
    """
    if not symbols:
        return "Symbols empty!"

    name_project = project_name
    path_project = os.path.join(path_data, name_project)
    file_vocab_project = os.path.join(path_project, "vocab.txt")

    file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
    if not os.path.isfile(file_vocab):
        return f"the file {file_vocab} not found !"

    # str.split never returns an empty list, so emptiness has to be checked
    # after cleaning: strip spaces and drop entries that become blank.
    requested = [item.replace(" ", "") for item in symbols.split(",")]
    requested = [item for item in requested if item]
    if not requested:
        return "Symbols to extend not found."

    with open(file_vocab, "r", encoding="utf-8-sig") as f:
        vocab = f.read().split("\n")
    known = set(vocab)

    # Keep first-seen order and skip repeats, so each new symbol is added to
    # the vocab (and gets an embedding row) exactly once.
    miss_symbols = []
    for item in requested:
        if item in known:
            continue
        known.add(item)
        miss_symbols.append(item)

    if not miss_symbols:
        return "Symbols are okay no need to extend."

    # Reported as-is: this counts the trailing empty element when the file
    # ends with a newline.
    size_vocab = len(vocab)

    if vocab[-1] == "":
        # The file ended with a newline: insert before that trailing empty
        # element so no blank line appears between old and new symbols and the
        # final newline is kept.
        # NOTE(review): presumably the tokenizer maps line order to embedding
        # rows, which would make a stray blank line shift indices - confirm.
        vocab[-1:-1] = miss_symbols
    else:
        vocab.extend(miss_symbols)

    with open(file_vocab_project, "w", encoding="utf-8-sig") as f:
        f.write("\n".join(vocab))

    if model_type == "F5-TTS":
        ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
    else:
        ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))

    new_ckpt_path = os.path.join(path_project_ckpts, name_project)
    os.makedirs(new_ckpt_path, exist_ok=True)
    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")

    size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=len(miss_symbols))

    vocab_new = "\n".join(miss_symbols)
    return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {len(miss_symbols)}\nnew symbols :\n{vocab_new}"
|
823 |
+
|
824 |
+
|
825 |
def vocab_check(project_name):
|
826 |
name_project = project_name
|
827 |
path_project = os.path.join(path_data, name_project)
|
|
|
830 |
|
831 |
file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
|
832 |
if not os.path.isfile(file_vocab):
|
833 |
+
return f"the file {file_vocab} not found !", ""
|
834 |
|
835 |
with open(file_vocab, "r", encoding="utf-8-sig") as f:
|
836 |
data = f.read()
|
|
|
838 |
vocab = set(vocab)
|
839 |
|
840 |
if not os.path.isfile(file_metadata):
|
841 |
+
return f"the file {file_metadata} not found !", ""
|
842 |
|
843 |
with open(file_metadata, "r", encoding="utf-8-sig") as f:
|
844 |
data = f.read()
|
|
|
856 |
if t not in vocab and t not in miss_symbols_keep:
|
857 |
miss_symbols.append(t)
|
858 |
miss_symbols_keep[t] = t
|
859 |
+
|
860 |
if miss_symbols == []:
|
861 |
+
vocab_miss = ""
|
862 |
info = "You can train using your language !"
|
863 |
else:
|
864 |
+
vocab_miss = ",".join(miss_symbols)
|
865 |
+
info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
|
866 |
|
867 |
+
return info, vocab_miss
|
868 |
|
869 |
|
870 |
def get_random_sample_prepare(project_name):
|
|
|
1103 |
outputs=[random_text_transcribe, random_audio_transcribe],
|
1104 |
)
|
1105 |
|
1106 |
+
with gr.TabItem("vocab check"):
|
1107 |
+
gr.Markdown("""```plaintext
|
1108 |
+
check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. for finetune new language
|
1109 |
+
```""")
|
1110 |
+
|
1111 |
+
check_button = gr.Button("check vocab")
|
1112 |
+
txt_info_check = gr.Text(label="info", value="")
|
1113 |
+
|
1114 |
+
gr.Markdown("""```plaintext
|
1115 |
+
Using the extended model, you can fine-tune to a new language that is missing symbols in the vocab , this create a new model with a new vocabulary size and save it in your ckpts/project folder.
|
1116 |
+
```""")
|
1117 |
+
|
1118 |
+
exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
|
1119 |
+
|
1120 |
+
with gr.Row():
|
1121 |
+
txt_extend = gr.Textbox(
|
1122 |
+
label="Symbols",
|
1123 |
+
value="",
|
1124 |
+
placeholder="To add new symbols, make sure to use ',' for each symbol",
|
1125 |
+
scale=6,
|
1126 |
+
)
|
1127 |
+
txt_count_symbol = gr.Textbox(label="new size vocab", value="", scale=1)
|
1128 |
+
|
1129 |
+
extend_button = gr.Button("Extended")
|
1130 |
+
txt_info_extend = gr.Text(label="info", value="")
|
1131 |
+
|
1132 |
+
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
|
1133 |
+
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
|
1134 |
+
extend_button.click(
|
1135 |
+
fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
|
1136 |
+
)
|
1137 |
+
|
1138 |
with gr.TabItem("prepare Data"):
|
1139 |
gr.Markdown(
|
1140 |
"""```plaintext
|
|
|
1156 |
|
1157 |
```"""
|
1158 |
)
|
1159 |
+
ch_tokenizern = gr.Checkbox(label="create vocabulary", value=False, visible=False)
|
1160 |
bt_prepare = bt_create = gr.Button("prepare")
|
1161 |
txt_info_prepare = gr.Text(label="info", value="")
|
1162 |
txt_vocab_prepare = gr.Text(label="vocab", value="")
|
|
|
1174 |
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
|
1175 |
)
|
1176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1177 |
with gr.TabItem("train Data"):
|
1178 |
gr.Markdown("""```plaintext
|
1179 |
The auto-setting is still experimental. Please make sure that the epochs , save per updates , and last per steps are set correctly, or change them manually as needed.
|