mrfakename
commited on
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- README_REPO.md +3 -0
- inference-cli.py +51 -15
README_REPO.md
CHANGED
@@ -86,6 +86,9 @@ Currently support 30s for a single generation, which is the **TOTAL** length of
|
|
86 |
|
87 |
Either you can specify everything in `inference-cli.toml` or override with flags. Leave `--ref_text ""` will have ASR model transcribe the reference audio automatically (use extra GPU memory). If encounter network error, consider use local ckpt, just set `ckpt_path` in `inference-cli.py`
|
88 |
|
|
|
|
|
|
|
89 |
```bash
|
90 |
python inference-cli.py \
|
91 |
--model "F5-TTS" \
|
|
|
86 |
|
87 |
Either you can specify everything in `inference-cli.toml` or override with flags. Leave `--ref_text ""` will have ASR model transcribe the reference audio automatically (use extra GPU memory). If encounter network error, consider use local ckpt, just set `ckpt_path` in `inference-cli.py`
|
88 |
|
89 |
+
for change model use --ckpt_file to specify the model you want to load,
|
90 |
+
for change vocab.txt use --vocab_file to provide your vocab.txt file.
|
91 |
+
|
92 |
```bash
|
93 |
python inference-cli.py \
|
94 |
--model "F5-TTS" \
|
inference-cli.py
CHANGED
@@ -36,6 +36,16 @@ parser.add_argument(
|
|
36 |
"--model",
|
37 |
help="F5-TTS | E2-TTS",
|
38 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
parser.add_argument(
|
40 |
"-r",
|
41 |
"--ref_audio",
|
@@ -88,6 +98,8 @@ if gen_file:
|
|
88 |
gen_text = codecs.open(gen_file, "r", "utf-8").read()
|
89 |
output_dir = args.output_dir if args.output_dir else config["output_dir"]
|
90 |
model = args.model if args.model else config["model"]
|
|
|
|
|
91 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
92 |
wave_path = Path(output_dir)/"out.wav"
|
93 |
spectrogram_path = Path(output_dir)/"out.png"
|
@@ -125,11 +137,19 @@ speed = 1.0
|
|
125 |
# fix_duration = 27 # None or float (duration in seconds)
|
126 |
fix_duration = None
|
127 |
|
128 |
-
def load_model(
|
129 |
-
|
130 |
-
if
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
model = CFM(
|
134 |
transformer=model_cls(
|
135 |
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
@@ -149,14 +169,12 @@ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
|
|
149 |
|
150 |
return model
|
151 |
|
152 |
-
|
153 |
# load models
|
154 |
F5TTS_model_cfg = dict(
|
155 |
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
156 |
)
|
157 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
158 |
|
159 |
-
|
160 |
def chunk_text(text, max_chars=135):
|
161 |
"""
|
162 |
Splits the input text into chunks, each with a maximum number of characters.
|
@@ -184,12 +202,29 @@ def chunk_text(text, max_chars=135):
|
|
184 |
|
185 |
return chunks
|
186 |
|
|
|
|
|
|
|
187 |
|
188 |
-
def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
|
189 |
if model == "F5-TTS":
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
elif model == "E2-TTS":
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
audio, sr = ref_audio
|
195 |
if audio.shape[0] > 1:
|
@@ -325,7 +360,7 @@ def process_voice(ref_audio_orig, ref_text):
|
|
325 |
print("Using custom reference text...")
|
326 |
return ref_audio, ref_text
|
327 |
|
328 |
-
def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
|
329 |
print(gen_text)
|
330 |
# Add the functionality to ensure it ends with ". "
|
331 |
if not ref_text.endswith(". ") and not ref_text.endswith("。"):
|
@@ -343,10 +378,10 @@ def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_durat
|
|
343 |
print(f'gen_text {i}', gen_text)
|
344 |
|
345 |
print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
|
346 |
-
return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
|
347 |
|
348 |
|
349 |
-
def process(ref_audio, ref_text, text_gen, model, remove_silence):
|
350 |
main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
|
351 |
if "voices" not in config:
|
352 |
voices = {"main": main_voice}
|
@@ -371,7 +406,7 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
|
|
371 |
ref_audio = voices[voice]['ref_audio']
|
372 |
ref_text = voices[voice]['ref_text']
|
373 |
print(f"Voice: {voice}")
|
374 |
-
audio, spectragram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
|
375 |
generated_audio_segments.append(audio)
|
376 |
|
377 |
if generated_audio_segments:
|
@@ -389,4 +424,5 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
|
|
389 |
aseg.export(f.name, format="wav")
|
390 |
print(f.name)
|
391 |
|
392 |
-
|
|
|
|
36 |
"--model",
|
37 |
help="F5-TTS | E2-TTS",
|
38 |
)
|
39 |
+
parser.add_argument(
|
40 |
+
"-p",
|
41 |
+
"--ckpt_file",
|
42 |
+
help="The Checkpoint .pt",
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
"-v",
|
46 |
+
"--vocab_file",
|
47 |
+
help="The vocab .txt",
|
48 |
+
)
|
49 |
parser.add_argument(
|
50 |
"-r",
|
51 |
"--ref_audio",
|
|
|
98 |
gen_text = codecs.open(gen_file, "r", "utf-8").read()
|
99 |
output_dir = args.output_dir if args.output_dir else config["output_dir"]
|
100 |
model = args.model if args.model else config["model"]
|
101 |
+
ckpt_file = args.ckpt_file if args.ckpt_file else ""
|
102 |
+
vocab_file = args.vocab_file if args.vocab_file else ""
|
103 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
104 |
wave_path = Path(output_dir)/"out.wav"
|
105 |
spectrogram_path = Path(output_dir)/"out.png"
|
|
|
137 |
# fix_duration = 27 # None or float (duration in seconds)
|
138 |
fix_duration = None
|
139 |
|
140 |
+
def load_model(model_cls, model_cfg, ckpt_path,file_vocab):
|
141 |
+
|
142 |
+
if file_vocab=="":
|
143 |
+
file_vocab="Emilia_ZH_EN"
|
144 |
+
tokenizer="pinyin"
|
145 |
+
else:
|
146 |
+
tokenizer="custom"
|
147 |
+
|
148 |
+
print("\nvocab : ",vocab_file,tokenizer)
|
149 |
+
print("tokenizer : ",tokenizer)
|
150 |
+
print("model : ",ckpt_path,"\n")
|
151 |
+
|
152 |
+
vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
|
153 |
model = CFM(
|
154 |
transformer=model_cls(
|
155 |
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
|
|
169 |
|
170 |
return model
|
171 |
|
|
|
172 |
# load models
|
173 |
F5TTS_model_cfg = dict(
|
174 |
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
175 |
)
|
176 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
177 |
|
|
|
178 |
def chunk_text(text, max_chars=135):
|
179 |
"""
|
180 |
Splits the input text into chunks, each with a maximum number of characters.
|
|
|
202 |
|
203 |
return chunks
|
204 |
|
205 |
+
#ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
206 |
+
#if not Path(ckpt_path).exists():
|
207 |
+
#ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
208 |
|
209 |
+
def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
|
210 |
if model == "F5-TTS":
|
211 |
+
|
212 |
+
if ckpt_file == "":
|
213 |
+
repo_name= "F5-TTS"
|
214 |
+
exp_name = "F5TTS_Base"
|
215 |
+
ckpt_step= 1200000
|
216 |
+
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
217 |
+
|
218 |
+
ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,file_vocab)
|
219 |
+
|
220 |
elif model == "E2-TTS":
|
221 |
+
if ckpt_file == "":
|
222 |
+
repo_name= "E2-TTS"
|
223 |
+
exp_name = "E2TTS_Base"
|
224 |
+
ckpt_step= 1200000
|
225 |
+
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
226 |
+
|
227 |
+
ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,file_vocab)
|
228 |
|
229 |
audio, sr = ref_audio
|
230 |
if audio.shape[0] > 1:
|
|
|
360 |
print("Using custom reference text...")
|
361 |
return ref_audio, ref_text
|
362 |
|
363 |
+
def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
|
364 |
print(gen_text)
|
365 |
# Add the functionality to ensure it ends with ". "
|
366 |
if not ref_text.endswith(". ") and not ref_text.endswith("。"):
|
|
|
378 |
print(f'gen_text {i}', gen_text)
|
379 |
|
380 |
print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
|
381 |
+
return infer_batch((audio, sr), ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration)
|
382 |
|
383 |
|
384 |
+
def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
|
385 |
main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
|
386 |
if "voices" not in config:
|
387 |
voices = {"main": main_voice}
|
|
|
406 |
ref_audio = voices[voice]['ref_audio']
|
407 |
ref_text = voices[voice]['ref_text']
|
408 |
print(f"Voice: {voice}")
|
409 |
+
audio, spectragram = infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence)
|
410 |
generated_audio_segments.append(audio)
|
411 |
|
412 |
if generated_audio_segments:
|
|
|
424 |
aseg.export(f.name, format="wav")
|
425 |
print(f.name)
|
426 |
|
427 |
+
|
428 |
+
process(ref_audio, ref_text, gen_text, model,ckpt_file,vocab_file, remove_silence)
|