mrfakename commited on
Commit
45012e5
·
verified ·
1 Parent(s): ca722aa

add chunking

Browse files
Files changed (1) hide show
  1. app.py +52 -46
app.py CHANGED
@@ -19,8 +19,9 @@ from model.utils import (
19
  from transformers import pipeline
20
  import spaces
21
  import librosa
 
22
 
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
 
25
  pipe = pipeline(
26
  "automatic-speech-recognition",
@@ -77,7 +78,7 @@ F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cf
77
  E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
78
 
79
  @spaces.GPU
80
- def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
81
  print(gen_text)
82
  if len(gen_text) > 200:
83
  raise gr.Error("Please keep your text under 200 chars.")
@@ -122,44 +123,49 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
122
  resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
123
  audio = resampler(audio)
124
  audio = audio.to(device)
125
-
126
- # Prepare the text
127
- text_list = [ref_text + gen_text]
128
- final_text_list = convert_char_to_pinyin(text_list)
129
-
130
- # Calculate duration
131
- ref_audio_len = audio.shape[-1] // hop_length
132
- # if fix_duration is not None:
133
- # duration = int(fix_duration * target_sample_rate / hop_length)
134
- # else:
135
- zh_pause_punc = r"。,、;:?!"
136
- ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
137
- gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
138
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
139
-
140
- # inference
141
- gr.Info(f"Generating audio using {exp_name}")
142
- with torch.inference_mode():
143
- generated, _ = base_model.sample(
144
- cond=audio,
145
- text=final_text_list,
146
- duration=duration,
147
- steps=nfe_step,
148
- cfg_strength=cfg_strength,
149
- sway_sampling_coef=sway_sampling_coef,
150
- )
151
-
152
- generated = generated[:, ref_audio_len:, :]
153
- generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
154
- gr.Info("Running vocoder")
155
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
156
- generated_wave = vocos.decode(generated_mel_spec.cpu())
157
- if rms < target_rms:
158
- generated_wave = generated_wave * rms / target_rms
159
-
160
- # wav -> numpy
161
- generated_wave = generated_wave.squeeze().cpu().numpy()
162
-
 
 
 
 
 
163
  if remove_silence:
164
  gr.Info("Removing audio silences... This may take a moment")
165
  non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
@@ -171,11 +177,11 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
171
 
172
 
173
  # spectogram
174
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
175
- spectrogram_path = tmp_spectrogram.name
176
- save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
177
 
178
- return (target_sample_rate, generated_wave), spectrogram_path
179
 
180
  with gr.Blocks() as app:
181
  gr.Markdown("""
@@ -206,9 +212,9 @@ Long-form/batched inference + speech editing is coming soon!
206
  remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
207
 
208
  audio_output = gr.Audio(label="Synthesized Audio")
209
- spectrogram_output = gr.Image(label="Spectrogram")
210
 
211
- generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output, spectrogram_output])
212
  gr.Markdown("""
213
  ## Run Locally
214
 
 
19
  from transformers import pipeline
20
  import spaces
21
  import librosa
22
+ from txtsplit import txtsplit
23
 
24
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() "cpu"
25
 
26
  pipe = pipeline(
27
  "automatic-speech-recognition",
 
78
  E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
79
 
80
  @spaces.GPU
81
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
82
  print(gen_text)
83
  if len(gen_text) > 200:
84
  raise gr.Error("Please keep your text under 200 chars.")
 
123
  resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
124
  audio = resampler(audio)
125
  audio = audio.to(device)
126
+ # Chunk
127
+ chunks = txtsplit(gen_text, 100, 150) # 100 chars preferred, 150 max
128
+ results = []
129
+ generated_mel_specs = []
130
+ for chunk in progress.tqdm(chunks):
131
+ # Prepare the text
132
+ text_list = [ref_text + chunk]
133
+ final_text_list = convert_char_to_pinyin(text_list)
134
+
135
+ # Calculate duration
136
+ ref_audio_len = audio.shape[-1] // hop_length
137
+ # if fix_duration is not None:
138
+ # duration = int(fix_duration * target_sample_rate / hop_length)
139
+ # else:
140
+ zh_pause_punc = r"。,、;:?!"
141
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
142
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
143
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
144
+
145
+ # inference
146
+ gr.Info(f"Generating audio using {exp_name}")
147
+ with torch.inference_mode():
148
+ generated, _ = base_model.sample(
149
+ cond=audio,
150
+ text=final_text_list,
151
+ duration=duration,
152
+ steps=nfe_step,
153
+ cfg_strength=cfg_strength,
154
+ sway_sampling_coef=sway_sampling_coef,
155
+ )
156
+
157
+ generated = generated[:, ref_audio_len:, :]
158
+ generated_mel_specs.append(rearrange(generated, '1 n d -> 1 d n'))
159
+ gr.Info("Running vocoder")
160
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
161
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
162
+ if rms < target_rms:
163
+ generated_wave = generated_wave * rms / target_rms
164
+
165
+ # wav -> numpy
166
+ generated_wave = generated_wave.squeeze().cpu().numpy()
167
+ results.append(generated_wave)
168
+ generated_wave = np.concatenate(results)
169
  if remove_silence:
170
  gr.Info("Removing audio silences... This may take a moment")
171
  non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
 
177
 
178
 
179
  # spectogram
180
+ # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
181
+ # spectrogram_path = tmp_spectrogram.name
182
+ # save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
183
 
184
+ return (target_sample_rate, generated_wave)
185
 
186
  with gr.Blocks() as app:
187
  gr.Markdown("""
 
212
  remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
213
 
214
  audio_output = gr.Audio(label="Synthesized Audio")
215
+ # spectrogram_output = gr.Image(label="Spectrogram")
216
 
217
+ generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
218
  gr.Markdown("""
219
  ## Run Locally
220