mrfakename commited on
Commit
9c54d62
·
verified ·
1 Parent(s): d37849f

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (2) hide show
  1. api.py +117 -0
  2. model/utils_infer.py +54 -20
api.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import soundfile as sf
2
+ import torch
3
+ import tqdm
4
+ from cached_path import cached_path
5
+
6
+ from model import DiT, UNetT
7
+ from model.utils import save_spectrogram
8
+
9
+ from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
10
+
11
+
12
+ class F5TTS:
13
+ def __init__(
14
+ self,
15
+ model_type="F5-TTS",
16
+ ckpt_file="",
17
+ vocab_file="",
18
+ ode_method="euler",
19
+ use_ema=True,
20
+ local_path=None,
21
+ device=None,
22
+ ):
23
+ # Initialize parameters
24
+ self.final_wave = None
25
+ self.target_sample_rate = 24000
26
+ self.n_mel_channels = 100
27
+ self.hop_length = 256
28
+ self.target_rms = 0.1
29
+
30
+ # Set device
31
+ self.device = device or (
32
+ "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
33
+ )
34
+
35
+ # Load models
36
+ self.load_vecoder_model(local_path)
37
+ self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
38
+
39
+ def load_vecoder_model(self, local_path):
40
+ self.vocos = load_vocoder(local_path is not None, local_path, self.device)
41
+
42
+ def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
43
+ if model_type == "F5-TTS":
44
+ if not ckpt_file:
45
+ ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
46
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
47
+ model_cls = DiT
48
+ elif model_type == "E2-TTS":
49
+ if not ckpt_file:
50
+ ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
51
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
52
+ model_cls = UNetT
53
+ else:
54
+ raise ValueError(f"Unknown model type: {model_type}")
55
+
56
+ self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
57
+
58
+ def export_wav(self, wav, file_wave, remove_silence=False):
59
+ if remove_silence:
60
+ remove_silence_for_generated_wav(file_wave)
61
+
62
+ sf.write(file_wave, wav, self.target_sample_rate)
63
+
64
+ def export_spectrogram(self, spect, file_spect):
65
+ save_spectrogram(spect, file_spect)
66
+
67
+ def infer(
68
+ self,
69
+ ref_file,
70
+ ref_text,
71
+ gen_text,
72
+ sway_sampling_coef=-1,
73
+ cfg_strength=2,
74
+ nfe_step=32,
75
+ speed=1.0,
76
+ fix_duration=None,
77
+ remove_silence=False,
78
+ file_wave=None,
79
+ file_spect=None,
80
+ cross_fade_duration=0.15,
81
+ show_info=print,
82
+ progress=tqdm,
83
+ ):
84
+ wav, sr, spect = infer_process(
85
+ ref_file,
86
+ ref_text,
87
+ gen_text,
88
+ self.ema_model,
89
+ cross_fade_duration,
90
+ speed,
91
+ show_info,
92
+ progress,
93
+ nfe_step,
94
+ cfg_strength,
95
+ sway_sampling_coef,
96
+ fix_duration,
97
+ )
98
+
99
+ if file_wave is not None:
100
+ self.export_wav(wav, file_wave, remove_silence)
101
+
102
+ if file_spect is not None:
103
+ self.export_spectrogram(spect, file_spect)
104
+
105
+ return wav, sr, spect
106
+
107
+
108
+ if __name__ == "__main__":
109
+ f5tts = F5TTS()
110
+
111
+ wav, sr, spect = f5tts.infer(
112
+ ref_file="tests/ref_audio/test_en_1_ref_short.wav",
113
+ ref_text="some call me nature, others call me mother nature.",
114
+ gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
115
+ file_wave="tests/out.wav",
116
+ file_spect="tests/out.png",
117
+ )
model/utils_infer.py CHANGED
@@ -38,12 +38,12 @@ target_sample_rate = 24000
38
  n_mel_channels = 100
39
  hop_length = 256
40
  target_rms = 0.1
41
- nfe_step = 32 # 16, 32
42
- cfg_strength = 2.0
43
- ode_method = "euler"
44
- sway_sampling_coef = -1.0
45
- speed = 1.0
46
- fix_duration = None
47
 
48
  # -----------------------------------------
49
 
@@ -84,7 +84,7 @@ def chunk_text(text, max_chars=135):
84
  # load vocoder
85
 
86
 
87
- def load_vocoder(is_local=False, local_path=""):
88
  if is_local:
89
  print(f"Load vocos from local path {local_path}")
90
  vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
@@ -100,14 +100,14 @@ def load_vocoder(is_local=False, local_path=""):
100
  # load model for inference
101
 
102
 
103
- def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
104
  if vocab_file == "":
105
  vocab_file = "Emilia_ZH_EN"
106
  tokenizer = "pinyin"
107
  else:
108
  tokenizer = "custom"
109
 
110
- print("\nvocab : ", vocab_file, tokenizer)
111
  print("tokenizer : ", tokenizer)
112
  print("model : ", ckpt_path, "\n")
113
 
@@ -125,7 +125,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
125
  vocab_char_map=vocab_char_map,
126
  ).to(device)
127
 
128
- model = load_checkpoint(model, ckpt_path, device, use_ema=True)
129
 
130
  return model
131
 
@@ -178,7 +178,18 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
178
 
179
 
180
  def infer_process(
181
- ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm
 
 
 
 
 
 
 
 
 
 
 
182
  ):
183
  # Split the input text into batches
184
  audio, sr = torchaudio.load(ref_audio)
@@ -188,14 +199,36 @@ def infer_process(
188
  print(f"gen_text {i}", gen_text)
189
 
190
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
191
- return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
 
194
  # infer batches
195
 
196
 
197
  def infer_batch_process(
198
- ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm
 
 
 
 
 
 
 
 
 
 
199
  ):
200
  audio, sr = ref_audio
201
  if audio.shape[0] > 1:
@@ -219,11 +252,14 @@ def infer_batch_process(
219
  text_list = [ref_text + gen_text]
220
  final_text_list = convert_char_to_pinyin(text_list)
221
 
222
- # Calculate duration
223
- ref_audio_len = audio.shape[-1] // hop_length
224
- ref_text_len = len(ref_text.encode("utf-8"))
225
- gen_text_len = len(gen_text.encode("utf-8"))
226
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
 
 
 
227
 
228
  # inference
229
  with torch.inference_mode():
@@ -293,8 +329,6 @@ def infer_batch_process(
293
 
294
 
295
  # remove silence from generated wav
296
-
297
-
298
  def remove_silence_for_generated_wav(filename):
299
  aseg = AudioSegment.from_file(filename)
300
  non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
 
38
  n_mel_channels = 100
39
  hop_length = 256
40
  target_rms = 0.1
41
+ # nfe_step = 32 # 16, 32
42
+ # cfg_strength = 2.0
43
+ # ode_method = "euler"
44
+ # sway_sampling_coef = -1.0
45
+ # speed = 1.0
46
+ # fix_duration = None
47
 
48
  # -----------------------------------------
49
 
 
84
  # load vocoder
85
 
86
 
87
+ def load_vocoder(is_local=False, local_path="", device=device):
88
  if is_local:
89
  print(f"Load vocos from local path {local_path}")
90
  vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
 
100
  # load model for inference
101
 
102
 
103
+ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler", use_ema=True, device=device):
104
  if vocab_file == "":
105
  vocab_file = "Emilia_ZH_EN"
106
  tokenizer = "pinyin"
107
  else:
108
  tokenizer = "custom"
109
 
110
+ print("\nvocab : ", vocab_file)
111
  print("tokenizer : ", tokenizer)
112
  print("model : ", ckpt_path, "\n")
113
 
 
125
  vocab_char_map=vocab_char_map,
126
  ).to(device)
127
 
128
+ model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
129
 
130
  return model
131
 
 
178
 
179
 
180
  def infer_process(
181
+ ref_audio,
182
+ ref_text,
183
+ gen_text,
184
+ model_obj,
185
+ cross_fade_duration=0.15,
186
+ speed=1.0,
187
+ show_info=print,
188
+ progress=tqdm,
189
+ nfe_step=32,
190
+ cfg_strength=2,
191
+ sway_sampling_coef=-1,
192
+ fix_duration=None,
193
  ):
194
  # Split the input text into batches
195
  audio, sr = torchaudio.load(ref_audio)
 
199
  print(f"gen_text {i}", gen_text)
200
 
201
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
202
+ return infer_batch_process(
203
+ (audio, sr),
204
+ ref_text,
205
+ gen_text_batches,
206
+ model_obj,
207
+ cross_fade_duration,
208
+ speed,
209
+ progress,
210
+ nfe_step,
211
+ cfg_strength,
212
+ sway_sampling_coef,
213
+ fix_duration,
214
+ )
215
 
216
 
217
  # infer batches
218
 
219
 
220
  def infer_batch_process(
221
+ ref_audio,
222
+ ref_text,
223
+ gen_text_batches,
224
+ model_obj,
225
+ cross_fade_duration=0.15,
226
+ speed=1,
227
+ progress=tqdm,
228
+ nfe_step=32,
229
+ cfg_strength=2.0,
230
+ sway_sampling_coef=-1,
231
+ fix_duration=None,
232
  ):
233
  audio, sr = ref_audio
234
  if audio.shape[0] > 1:
 
252
  text_list = [ref_text + gen_text]
253
  final_text_list = convert_char_to_pinyin(text_list)
254
 
255
+ if fix_duration is not None:
256
+ duration = int(fix_duration * target_sample_rate / hop_length)
257
+ else:
258
+ # Calculate duration
259
+ ref_audio_len = audio.shape[-1] // hop_length
260
+ ref_text_len = len(ref_text.encode("utf-8"))
261
+ gen_text_len = len(gen_text.encode("utf-8"))
262
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
263
 
264
  # inference
265
  with torch.inference_mode():
 
329
 
330
 
331
  # remove silence from generated wav
 
 
332
  def remove_silence_for_generated_wav(filename):
333
  aseg = AudioSegment.from_file(filename)
334
  non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)