mrfakename commited on
Commit
b624c42
·
verified ·
1 Parent(s): 118c154

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

api.py CHANGED
@@ -69,6 +69,10 @@ class F5TTS:
69
  ref_file,
70
  ref_text,
71
  gen_text,
 
 
 
 
72
  sway_sampling_coef=-1,
73
  cfg_strength=2,
74
  nfe_step=32,
@@ -77,23 +81,21 @@ class F5TTS:
77
  remove_silence=False,
78
  file_wave=None,
79
  file_spect=None,
80
- cross_fade_duration=0.15,
81
- show_info=print,
82
- progress=tqdm,
83
  ):
84
  wav, sr, spect = infer_process(
85
  ref_file,
86
  ref_text,
87
  gen_text,
88
  self.ema_model,
89
- cross_fade_duration,
90
- speed,
91
- show_info,
92
- progress,
93
- nfe_step,
94
- cfg_strength,
95
- sway_sampling_coef,
96
- fix_duration,
 
97
  )
98
 
99
  if file_wave is not None:
 
69
  ref_file,
70
  ref_text,
71
  gen_text,
72
+ show_info=print,
73
+ progress=tqdm,
74
+ target_rms=0.1,
75
+ cross_fade_duration=0.15,
76
  sway_sampling_coef=-1,
77
  cfg_strength=2,
78
  nfe_step=32,
 
81
  remove_silence=False,
82
  file_wave=None,
83
  file_spect=None,
 
 
 
84
  ):
85
  wav, sr, spect = infer_process(
86
  ref_file,
87
  ref_text,
88
  gen_text,
89
  self.ema_model,
90
+ show_info=show_info,
91
+ progress=progress,
92
+ target_rms=target_rms,
93
+ cross_fade_duration=cross_fade_duration,
94
+ nfe_step=nfe_step,
95
+ cfg_strength=cfg_strength,
96
+ sway_sampling_coef=sway_sampling_coef,
97
+ speed=speed,
98
+ fix_duration=fix_duration,
99
  )
100
 
101
  if file_wave is not None:
model/backbones/dit.py CHANGED
@@ -45,9 +45,9 @@ class TextEmbedding(nn.Module):
45
  self.extra_modeling = False
46
 
47
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
48
- batch, text_len = text.shape[0], text.shape[1]
49
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
50
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
 
51
  text = F.pad(text, (0, seq_len - text_len), value=0)
52
 
53
  if drop_text: # cfg for text
 
45
  self.extra_modeling = False
46
 
47
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
48
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
49
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
+ batch, text_len = text.shape[0], text.shape[1]
51
  text = F.pad(text, (0, seq_len - text_len), value=0)
52
 
53
  if drop_text: # cfg for text
model/backbones/unett.py CHANGED
@@ -48,9 +48,9 @@ class TextEmbedding(nn.Module):
48
  self.extra_modeling = False
49
 
50
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
51
- batch, text_len = text.shape[0], text.shape[1]
52
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
53
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
 
54
  text = F.pad(text, (0, seq_len - text_len), value=0)
55
 
56
  if drop_text: # cfg for text
 
48
  self.extra_modeling = False
49
 
50
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
51
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
52
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
+ batch, text_len = text.shape[0], text.shape[1]
54
  text = F.pad(text, (0, seq_len - text_len), value=0)
55
 
56
  if drop_text: # cfg for text
model/utils_infer.py CHANGED
@@ -31,12 +31,13 @@ target_sample_rate = 24000
31
  n_mel_channels = 100
32
  hop_length = 256
33
  target_rms = 0.1
34
- # nfe_step = 32 # 16, 32
35
- # cfg_strength = 2.0
36
- # ode_method = "euler"
37
- # sway_sampling_coef = -1.0
38
- # speed = 1.0
39
- # fix_duration = None
 
40
 
41
  # -----------------------------------------
42
 
@@ -107,7 +108,7 @@ def initialize_asr_pipeline(device=device):
107
  # load model for inference
108
 
109
 
110
- def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler", use_ema=True, device=device):
111
  if vocab_file == "":
112
  vocab_file = "Emilia_ZH_EN"
113
  tokenizer = "pinyin"
@@ -192,14 +193,15 @@ def infer_process(
192
  ref_text,
193
  gen_text,
194
  model_obj,
195
- cross_fade_duration=0.15,
196
- speed=1.0,
197
  show_info=print,
198
  progress=tqdm,
199
- nfe_step=32,
200
- cfg_strength=2,
201
- sway_sampling_coef=-1,
202
- fix_duration=None,
 
 
 
203
  ):
204
  # Split the input text into batches
205
  audio, sr = torchaudio.load(ref_audio)
@@ -214,13 +216,14 @@ def infer_process(
214
  ref_text,
215
  gen_text_batches,
216
  model_obj,
217
- cross_fade_duration,
218
- speed,
219
- progress,
220
- nfe_step,
221
- cfg_strength,
222
- sway_sampling_coef,
223
- fix_duration,
 
224
  )
225
 
226
 
@@ -232,12 +235,13 @@ def infer_batch_process(
232
  ref_text,
233
  gen_text_batches,
234
  model_obj,
235
- cross_fade_duration=0.15,
236
- speed=1,
237
  progress=tqdm,
 
 
238
  nfe_step=32,
239
  cfg_strength=2.0,
240
  sway_sampling_coef=-1,
 
241
  fix_duration=None,
242
  ):
243
  audio, sr = ref_audio
@@ -262,11 +266,11 @@ def infer_batch_process(
262
  text_list = [ref_text + gen_text]
263
  final_text_list = convert_char_to_pinyin(text_list)
264
 
 
265
  if fix_duration is not None:
266
  duration = int(fix_duration * target_sample_rate / hop_length)
267
  else:
268
  # Calculate duration
269
- ref_audio_len = audio.shape[-1] // hop_length
270
  ref_text_len = len(ref_text.encode("utf-8"))
271
  gen_text_len = len(gen_text.encode("utf-8"))
272
  duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
 
31
  n_mel_channels = 100
32
  hop_length = 256
33
  target_rms = 0.1
34
+ cross_fade_duration = 0.15
35
+ ode_method = "euler"
36
+ nfe_step = 32 # 16, 32
37
+ cfg_strength = 2.0
38
+ sway_sampling_coef = -1.0
39
+ speed = 1.0
40
+ fix_duration = None
41
 
42
  # -----------------------------------------
43
 
 
108
  # load model for inference
109
 
110
 
111
+ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
112
  if vocab_file == "":
113
  vocab_file = "Emilia_ZH_EN"
114
  tokenizer = "pinyin"
 
193
  ref_text,
194
  gen_text,
195
  model_obj,
 
 
196
  show_info=print,
197
  progress=tqdm,
198
+ target_rms=target_rms,
199
+ cross_fade_duration=cross_fade_duration,
200
+ nfe_step=nfe_step,
201
+ cfg_strength=cfg_strength,
202
+ sway_sampling_coef=sway_sampling_coef,
203
+ speed=speed,
204
+ fix_duration=fix_duration,
205
  ):
206
  # Split the input text into batches
207
  audio, sr = torchaudio.load(ref_audio)
 
216
  ref_text,
217
  gen_text_batches,
218
  model_obj,
219
+ progress=progress,
220
+ target_rms=target_rms,
221
+ cross_fade_duration=cross_fade_duration,
222
+ nfe_step=nfe_step,
223
+ cfg_strength=cfg_strength,
224
+ sway_sampling_coef=sway_sampling_coef,
225
+ speed=speed,
226
+ fix_duration=fix_duration,
227
  )
228
 
229
 
 
235
  ref_text,
236
  gen_text_batches,
237
  model_obj,
 
 
238
  progress=tqdm,
239
+ target_rms=0.1,
240
+ cross_fade_duration=0.15,
241
  nfe_step=32,
242
  cfg_strength=2.0,
243
  sway_sampling_coef=-1,
244
+ speed=1,
245
  fix_duration=None,
246
  ):
247
  audio, sr = ref_audio
 
266
  text_list = [ref_text + gen_text]
267
  final_text_list = convert_char_to_pinyin(text_list)
268
 
269
+ ref_audio_len = audio.shape[-1] // hop_length
270
  if fix_duration is not None:
271
  duration = int(fix_duration * target_sample_rate / hop_length)
272
  else:
273
  # Calculate duration
 
274
  ref_text_len = len(ref_text.encode("utf-8"))
275
  gen_text_len = len(gen_text.encode("utf-8"))
276
  duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)