SWivid commited on
Commit
e6226de
1 Parent(s): b53eca8

Update app_local.py

Browse files

Mainly redirect to split ckpt repos, along with some minor updates
fix: "gen_text" -> "chunk"

Files changed (1) hide show
  1. app_local.py +39 -22
app_local.py CHANGED
@@ -10,7 +10,7 @@ import tempfile
10
  from einops import rearrange
11
  from ema_pytorch import EMA
12
  from vocos import Vocos
13
- from pydub import AudioSegment
14
  from model import CFM, UNetT, DiT, MMDiT
15
  from cached_path import cached_path
16
  from model.utils import (
@@ -20,6 +20,7 @@ from model.utils import (
20
  )
21
  from transformers import pipeline
22
  import librosa
 
23
  from txtsplit import txtsplit
24
 
25
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
@@ -31,6 +32,8 @@ pipe = pipeline(
31
  device=device,
32
  )
33
 
 
 
34
  # --------------------- Settings -------------------- #
35
 
36
  target_sample_rate = 24000
@@ -45,8 +48,8 @@ speed = 1.0
45
  # fix_duration = 27 # None or float (duration in seconds)
46
  fix_duration = None
47
 
48
- def load_model(exp_name, model_cls, model_cfg, ckpt_step):
49
- checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
50
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
51
  model = CFM(
52
  transformer=model_cls(
@@ -69,20 +72,26 @@ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
69
  ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
70
  ema_model.copy_params_from_ema_to_model()
71
 
72
- return ema_model, model
73
 
74
  # load models
75
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
76
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
77
 
78
- F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
79
- E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
80
 
81
  def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
82
  print(gen_text)
83
  gr.Info("Converting audio...")
84
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
85
  aseg = AudioSegment.from_file(ref_audio_orig)
 
 
 
 
 
 
86
  # Convert to mono
87
  aseg = aseg.set_channels(1)
88
  audio_duration = len(aseg)
@@ -93,10 +102,8 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
93
  ref_audio = f.name
94
  if exp_name == "F5-TTS":
95
  ema_model = F5TTS_ema_model
96
- base_model = F5TTS_base_model
97
  elif exp_name == "E2-TTS":
98
  ema_model = E2TTS_ema_model
99
- base_model = E2TTS_base_model
100
 
101
  if not ref_text.strip():
102
  gr.Info("No reference text provided, transcribing reference audio...")
@@ -111,6 +118,7 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
111
  else:
112
  gr.Info("Using custom reference text...")
113
  audio, sr = torchaudio.load(ref_audio)
 
114
  # Audio
115
  if audio.shape[0] > 1:
116
  audio = torch.mean(audio, dim=0, keepdim=True)
@@ -122,7 +130,7 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
122
  audio = resampler(audio)
123
  audio = audio.to(device)
124
  # Chunk
125
- chunks = txtsplit(gen_text, 100, 150) # 100 chars preferred, 150 max
126
  results = []
127
  generated_mel_specs = []
128
  for chunk in progress.tqdm(chunks):
@@ -136,14 +144,14 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
136
  # duration = int(fix_duration * target_sample_rate / hop_length)
137
  # else:
138
  zh_pause_punc = r"。,、;:?!"
139
- ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
140
- gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
141
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
142
 
143
  # inference
144
  gr.Info(f"Generating audio using {exp_name}")
145
  with torch.inference_mode():
146
- generated, _ = base_model.sample(
147
  cond=audio,
148
  text=final_text_list,
149
  duration=duration,
@@ -155,7 +163,6 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
155
  generated = generated[:, ref_audio_len:, :]
156
  generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
157
  gr.Info("Running vocoder")
158
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
159
  generated_wave = vocos.decode(generated_mel_spec.cpu())
160
  if rms < target_rms:
161
  generated_wave = generated_wave * rms / target_rms
@@ -166,13 +173,23 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
166
  generated_wave = np.concatenate(results)
167
  if remove_silence:
168
  gr.Info("Removing audio silences... This may take a moment")
169
- non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
170
- non_silent_wave = np.array([])
171
- for interval in non_silent_intervals:
172
- start, end = interval
173
- non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
174
- generated_wave = non_silent_wave
175
-
 
 
 
 
 
 
 
 
 
 
176
 
177
  # spectogram
178
  # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
@@ -214,6 +231,6 @@ Long-form/batched inference + speech editing is coming soon!
214
 
215
  generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
216
  gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
217
-
218
 
219
  app.queue().launch()
 
10
  from einops import rearrange
11
  from ema_pytorch import EMA
12
  from vocos import Vocos
13
+ from pydub import AudioSegment, silence
14
  from model import CFM, UNetT, DiT, MMDiT
15
  from cached_path import cached_path
16
  from model.utils import (
 
20
  )
21
  from transformers import pipeline
22
  import librosa
23
+ import soundfile as sf
24
  from txtsplit import txtsplit
25
 
26
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
32
  device=device,
33
  )
34
 
35
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
36
+
37
  # --------------------- Settings -------------------- #
38
 
39
  target_sample_rate = 24000
 
48
  # fix_duration = 27 # None or float (duration in seconds)
49
  fix_duration = None
50
 
51
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
52
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
53
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
54
  model = CFM(
55
  transformer=model_cls(
 
72
  ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
73
  ema_model.copy_params_from_ema_to_model()
74
 
75
+ return model
76
 
77
  # load models
78
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
79
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
80
 
81
+ F5TTS_ema_model = load_model("F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
82
+ E2TTS_ema_model = load_model("E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
83
 
84
  def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
85
  print(gen_text)
86
  gr.Info("Converting audio...")
87
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
88
  aseg = AudioSegment.from_file(ref_audio_orig)
89
+ # remove long silence in reference audio
90
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
91
+ non_silent_wave = AudioSegment.silent(duration=0)
92
+ for non_silent_seg in non_silent_segs:
93
+ non_silent_wave += non_silent_seg
94
+ aseg = non_silent_wave
95
  # Convert to mono
96
  aseg = aseg.set_channels(1)
97
  audio_duration = len(aseg)
 
102
  ref_audio = f.name
103
  if exp_name == "F5-TTS":
104
  ema_model = F5TTS_ema_model
 
105
  elif exp_name == "E2-TTS":
106
  ema_model = E2TTS_ema_model
 
107
 
108
  if not ref_text.strip():
109
  gr.Info("No reference text provided, transcribing reference audio...")
 
118
  else:
119
  gr.Info("Using custom reference text...")
120
  audio, sr = torchaudio.load(ref_audio)
121
+ max_chars = int(len(ref_text) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
122
  # Audio
123
  if audio.shape[0] > 1:
124
  audio = torch.mean(audio, dim=0, keepdim=True)
 
130
  audio = resampler(audio)
131
  audio = audio.to(device)
132
  # Chunk
133
+ chunks = txtsplit(gen_text, 0.7*max_chars, 0.9*max_chars) # 100 chars preferred, 150 max
134
  results = []
135
  generated_mel_specs = []
136
  for chunk in progress.tqdm(chunks):
 
144
  # duration = int(fix_duration * target_sample_rate / hop_length)
145
  # else:
146
  zh_pause_punc = r"。,、;:?!"
147
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
148
+ chunk = len(chunk.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
149
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * chunk / speed)
150
 
151
  # inference
152
  gr.Info(f"Generating audio using {exp_name}")
153
  with torch.inference_mode():
154
+ generated, _ = ema_model.sample(
155
  cond=audio,
156
  text=final_text_list,
157
  duration=duration,
 
163
  generated = generated[:, ref_audio_len:, :]
164
  generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
165
  gr.Info("Running vocoder")
 
166
  generated_wave = vocos.decode(generated_mel_spec.cpu())
167
  if rms < target_rms:
168
  generated_wave = generated_wave * rms / target_rms
 
173
  generated_wave = np.concatenate(results)
174
  if remove_silence:
175
  gr.Info("Removing audio silences... This may take a moment")
176
+ # non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
177
+ # non_silent_wave = np.array([])
178
+ # for interval in non_silent_intervals:
179
+ # start, end = interval
180
+ # non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
181
+ # generated_wave = non_silent_wave
182
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
183
+ sf.write(f.name, generated_wave, target_sample_rate)
184
+ aseg = AudioSegment.from_file(f.name)
185
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
186
+ non_silent_wave = AudioSegment.silent(duration=0)
187
+ for non_silent_seg in non_silent_segs:
188
+ non_silent_wave += non_silent_seg
189
+ aseg = non_silent_wave
190
+ aseg.export(f.name, format="wav")
191
+ generated_wave, _ = torchaudio.load(f.name)
192
+ generated_wave = generated_wave.squeeze().cpu().numpy()
193
 
194
  # spectogram
195
  # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
 
231
 
232
  generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
233
  gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
234
+
235
 
236
  app.queue().launch()