danielr-ceva committed
Commit 57d5cf8 · verified · 1 parent: cd94843

Delete app.py

Files changed (1):
  1. app.py (+0, −525)
app.py DELETED
@@ -1,525 +0,0 @@
import io
import math
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple

import gradio as gr
import librosa
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import soundfile as sf
from PIL import Image

# -----------------------------
# Configuration
# -----------------------------
MAX_SECONDS = 10.0
ONNX_DIR = Path("./onnx")


@dataclass(frozen=True)
class ModelSpec:
    name: str
    sr: int
    onnx_path: str


# -----------------------------
# Model discovery and metadata
# -----------------------------
def _infer_model_meta(model_name: str) -> int:
    normalized = model_name.lower().replace("-", "_")

    if "48khz" in normalized or "48k" in normalized or "48hr" in normalized:
        return 48000

    # Fallback for unknown 16 kHz DPDFNet variants
    return 16000


def _display_label(spec: ModelSpec) -> str:
    khz = int(spec.sr // 1000)
    return f"{spec.name} ({khz} kHz)"


def discover_model_presets() -> Dict[str, ModelSpec]:
    ordered_names = [
        "baseline",
        "dpdfnet2",
        "dpdfnet4",
        "dpdfnet8",
        "dpdfnet2_48khz_hr",
        "dpdfnet8_48khz_hr",
    ]

    found_paths = {p.stem: p for p in ONNX_DIR.glob("*.onnx") if p.is_file()}
    presets: Dict[str, ModelSpec] = {}

    for name in ordered_names:
        p = found_paths.get(name)
        if p is None:
            continue
        sr = _infer_model_meta(name)
        spec = ModelSpec(
            name=name,
            sr=sr,
            onnx_path=str(p),
        )
        presets[_display_label(spec)] = spec

    # Include any additional ONNX files not in the canonical order list.
    for name, p in sorted(found_paths.items()):
        if name in ordered_names:
            continue
        sr = _infer_model_meta(name)
        spec = ModelSpec(
            name=name,
            sr=sr,
            onnx_path=str(p),
        )
        presets[_display_label(spec)] = spec

    return presets


MODEL_PRESETS = discover_model_presets()
DEFAULT_MODEL_KEY = next(iter(MODEL_PRESETS), None)

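
# Usage sketch (added for illustration; not part of the original app and never
# called by it): after discovery, each dropdown label maps to a ModelSpec,
# e.g. "baseline (16 kHz)" -> ModelSpec(name="baseline", sr=16000, ...).
def _debug_list_presets() -> None:
    for label, spec in MODEL_PRESETS.items():
        print(f"{label}: {spec.onnx_path} @ {spec.sr} Hz")
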

# -----------------------------
# ONNX Runtime + frontend cache
# -----------------------------
_SESSIONS: Dict[str, ort.InferenceSession] = {}
_INIT_STATES: Dict[str, np.ndarray] = {}


def resolve_model_path(local_path: str) -> str:
    p = Path(local_path)
    if p.exists():
        return str(p)
    raise gr.Error(
        f"ONNX model not found at: {local_path}. "
        "Expected local models under ./onnx/."
    )


def get_ort_session(model_key: str) -> ort.InferenceSession:
    if model_key in _SESSIONS:
        return _SESSIONS[model_key]

    spec = MODEL_PRESETS[model_key]
    onnx_path = resolve_model_path(spec.onnx_path)

    options = ort.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1

    sess = ort.InferenceSession(
        onnx_path,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )
    _SESSIONS[model_key] = sess
    return sess


def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
    if model_key in _INIT_STATES:
        return _INIT_STATES[model_key]

    if len(session.get_inputs()) < 2:
        raise gr.Error("Expected streaming ONNX model with two inputs: (spec, state).")

    meta = session.get_modelmeta().custom_metadata_map
    try:
        state_size = int(meta["state_size"])
        erb_norm_state_size = int(meta["erb_norm_state_size"])
        spec_norm_state_size = int(meta["spec_norm_state_size"])
        erb_norm_init = np.array(
            [float(x) for x in meta["erb_norm_init"].split(",")], dtype=np.float32
        )
        spec_norm_init = np.array(
            [float(x) for x in meta["spec_norm_init"].split(",")], dtype=np.float32
        )
    except KeyError as exc:
        raise gr.Error(
            f"ONNX model is missing required metadata key: {exc}. "
            "Re-export the model to embed state initialisation metadata."
        ) from exc

    init_state = np.zeros(state_size, dtype=np.float32)
    init_state[0:erb_norm_state_size] = erb_norm_init
    init_state[erb_norm_state_size:erb_norm_state_size + spec_norm_state_size] = spec_norm_init
    init_state = np.ascontiguousarray(init_state)

    _INIT_STATES[model_key] = init_state
    return init_state

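
# State layout sketch (illustration only, inferred from _load_initial_state):
#
#   [ erb_norm_init | spec_norm_init | zeros for the remaining model state ]
#
# A minimal self-check under that reading; not called by the app:
def _check_initial_state_layout(model_key: str) -> bool:
    sess = get_ort_session(model_key)
    state = _load_initial_state(model_key, sess)
    meta = sess.get_modelmeta().custom_metadata_map
    n_head = int(meta["erb_norm_state_size"]) + int(meta["spec_norm_state_size"])
    # Everything past the two normaliser segments should start out zeroed.
    return bool(np.all(state[n_head:] == 0.0))
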

# -----------------------------
# STFT/iSTFT (module-free)
# -----------------------------
def vorbis_window(window_len: int) -> np.ndarray:
    window_size_h = window_len / 2
    indices = np.arange(window_len)
    sin = np.sin(0.5 * np.pi * (indices + 0.5) / window_size_h)
    window = np.sin(0.5 * np.pi * sin * sin)
    return window.astype(np.float32)

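
# Property sketch (not called by the app): the Vorbis window satisfies the
# Princen-Bradley condition w[n]^2 + w[n + N/2]^2 = 1, which is why the
# 50%-overlap STFT/iSTFT pair below reconstructs interior samples exactly.
def _vorbis_cola_check(win_len: int = 320) -> bool:
    w = vorbis_window(win_len).astype(np.float64)
    half = win_len // 2
    return bool(np.allclose(w[:half] ** 2 + w[half:] ** 2, 1.0, atol=1e-6))
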

def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, np.ndarray]:
    # ONNX spec input is [B, T, F, 2] (or dynamic variants).
    spec_shape = session.get_inputs()[0].shape
    freq_bins = spec_shape[-2] if len(spec_shape) >= 2 else None

    if isinstance(freq_bins, int) and freq_bins > 1:
        win_len = int((freq_bins - 1) * 2)
    else:
        # 20 ms windows for DPDFNet family.
        sr = MODEL_PRESETS[model_key].sr
        win_len = int(round(sr * 0.02))

    hop = win_len // 2
    win = vorbis_window(win_len)
    return win_len, hop, win


def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
    audio = np.asarray(waveform, dtype=np.float32).reshape(-1)
    audio_pad = np.pad(audio, (0, win_len), mode="constant")

    spec = librosa.stft(
        y=audio_pad,
        n_fft=win_len,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        pad_mode="reflect",
    )
    spec = spec.T.astype(np.complex64, copy=False)  # [T, F]
    spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32, copy=False)  # [T, F, 2]
    return np.ascontiguousarray(spec_ri[None, ...], dtype=np.float32)  # [1, T, F, 2]


def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
    spec_c = np.asarray(spec_e[0], dtype=np.float32)  # [T, F, 2]
    spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64, copy=False)  # [F, T]

    waveform_e = librosa.istft(
        spec,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        length=None,
    ).astype(np.float32, copy=False)

    return np.concatenate(
        [waveform_e[win_len * 2 :], np.zeros(win_len * 2, dtype=np.float32)],
        axis=0,
    )

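
# Shape walk-through (illustrative), assuming a 16 kHz model: win_len = 320,
# hop = 160, F = 320 // 2 + 1 = 161 bins. A 1 s clip (16000 samples, plus
# win_len samples of zero padding) yields spec_ri of shape [1, T, 161, 2]
# with T = 1 + 16320 // 160 = 103 frames; _postprocess_spec inverts the STFT
# and drops the first 2 * win_len samples to compensate for the padding and
# algorithmic delay.
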

# -----------------------------
# ONNX inference (non-streaming pre/post, streaming ONNX state loop)
# -----------------------------
def enhance_audio_onnx(
    audio_mono: np.ndarray,
    model_key: str,
) -> np.ndarray:
    sess = get_ort_session(model_key)

    inputs = sess.get_inputs()
    outputs = sess.get_outputs()
    if len(inputs) < 2 or len(outputs) < 2:
        raise gr.Error(
            "Expected streaming ONNX signature with 2 inputs (spec, state) and 2 outputs (spec_e, state_out)."
        )

    in_spec_name = inputs[0].name
    in_state_name = inputs[1].name
    out_spec_name = outputs[0].name
    out_state_name = outputs[1].name

    waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
    win_len, hop, win = _infer_stft_params(model_key, sess)
    spec_r_np = _preprocess_waveform(waveform, win_len=win_len, hop=hop, win=win)

    state = _load_initial_state(model_key, sess).copy()
    spec_e_frames = []
    num_frames = int(spec_r_np.shape[1])

    for t in range(num_frames):
        spec_t = np.ascontiguousarray(spec_r_np[:, t : t + 1, :, :], dtype=np.float32)
        spec_e_t, state = sess.run(
            [out_spec_name, out_state_name],
            {in_spec_name: spec_t, in_state_name: state},
        )
        spec_e_frames.append(np.ascontiguousarray(spec_e_t, dtype=np.float32))

    if not spec_e_frames:
        return waveform

    spec_e_np = np.concatenate(spec_e_frames, axis=1)
    waveform_e = _postprocess_spec(spec_e_np, win_len=win_len, hop=hop, win=win)
    return np.asarray(waveform_e, dtype=np.float32).reshape(-1)

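
# Offline usage sketch (hypothetical file names; assumes at least one model
# was discovered under ./onnx/):
#
#     y, sr = sf.read("noisy.wav", dtype="float32")
#     y = y.mean(axis=1) if y.ndim > 1 else y
#     spec = MODEL_PRESETS[DEFAULT_MODEL_KEY]
#     y = librosa.resample(y, orig_sr=sr, target_sr=spec.sr)
#     sf.write("enhanced.wav", enhance_audio_onnx(y, DEFAULT_MODEL_KEY), spec.sr)
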

# -----------------------------
# Audio utilities
# -----------------------------
def _load_wav_from_gradio_path(path: str) -> Tuple[np.ndarray, int]:
    data, sr = sf.read(path, always_2d=True)
    data = data.astype(np.float32, copy=False)
    return data, int(sr)


def _to_mono(x: np.ndarray) -> Tuple[np.ndarray, int]:
    if x.ndim == 1:
        return x.astype(np.float32, copy=False), 1
    if x.shape[1] == 1:
        return x[:, 0], 1
    return x.mean(axis=1), int(x.shape[1])


def _resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    if sr_in == sr_out:
        return y
    return librosa.resample(y, orig_sr=sr_in, target_sr=sr_out).astype(np.float32, copy=False)


def _match_length(y: np.ndarray, target_len: int) -> np.ndarray:
    if len(y) == target_len:
        return y
    if len(y) > target_len:
        return y[:target_len]
    out = np.zeros((target_len,), dtype=y.dtype)
    out[: len(y)] = y
    return out


def _save_wav(y: np.ndarray, sr: int, prefix: str) -> str:
    tmp = tempfile.NamedTemporaryFile(prefix=prefix, suffix=".wav", delete=False)
    tmp.close()
    sf.write(tmp.name, y, sr)
    return tmp.name


def _spectrogram_image(y: np.ndarray, sr: int) -> Image.Image:
    win_length = max(256, int(0.032 * sr))
    hop_length = max(64, int(0.008 * sr))
    n_fft = 1 << (int(math.ceil(math.log2(win_length))))

    S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False)
    S_db = librosa.amplitude_to_db(np.abs(S) + 1e-10, ref=np.max)

    fig, ax = plt.subplots(figsize=(8.4, 3.2))
    ax.imshow(S_db, origin="lower", aspect="auto")
    ax.set_axis_off()
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=160)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

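
# Note (illustrative): n_fft in _spectrogram_image is win_length rounded up to
# the next power of two so the FFT stays fast, e.g. at sr = 48000,
# win_length = 1536 and n_fft = 2048; at 16 kHz, win_length = 512 is already a
# power of two.
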

# -----------------------------
# Main pipeline
# -----------------------------
def run_enhancement(
    source: str,
    mic_path: Optional[str],
    file_path: Optional[str],
    model_key: str,
):
    if not MODEL_PRESETS:
        raise gr.Error("No ONNX models found under ./onnx/. Add models and retry.")

    chosen_path = mic_path if source == "Microphone" else file_path
    if not chosen_path:
        raise gr.Error("Please provide audio either from the microphone or by uploading a file.")

    x, sr_orig = _load_wav_from_gradio_path(chosen_path)
    y_mono, n_ch = _to_mono(x)

    max_samples = int(MAX_SECONDS * sr_orig)
    was_trimmed = len(y_mono) > max_samples
    if was_trimmed:
        y_mono = y_mono[:max_samples]
    dur = len(y_mono) / float(sr_orig)

    spec = MODEL_PRESETS[model_key]
    sr_model = spec.sr

    y_model = _resample(y_mono, sr_orig, sr_model)
    y_enh_model = enhance_audio_onnx(y_model, model_key)

    y_enh = _resample(y_enh_model, sr_model, sr_orig)
    y_enh = _match_length(y_enh, len(y_mono))

    noisy_out = _save_wav(y_mono, sr_orig, prefix="noisy_mono_")
    enh_out = _save_wav(y_enh, sr_orig, prefix="enhanced_")

    noisy_img = _spectrogram_image(y_mono, sr_orig)
    enh_img = _spectrogram_image(y_enh, sr_orig)

    status = (
        f"**Input:** {sr_orig} Hz, {dur:.2f}s, channels={n_ch} ⭢ mono\n\n"
        f"**Model:** {spec.name} (runs at {sr_model} Hz)\n\n"
        + (
            f"**Resampling:** {sr_orig} ⭢ {sr_model} ⭢ {sr_orig}\n\n"
            if sr_orig != sr_model
            else "**Resampling:** none\n\n"
        )
        + (f"**Trimmed:** first {MAX_SECONDS:.0f}s used\n" if was_trimmed else "")
        + "\n✅ Done."
    )
    return noisy_out, enh_out, noisy_img, enh_img, status


def set_source_visibility(source: str):
    return (
        gr.update(visible=(source == "Microphone")),
        gr.update(visible=(source == "Upload")),
    )


# -----------------------------
# UI (light polish)
# -----------------------------
THEME = gr.themes.Soft(
    primary_hue="orange",
    neutral_hue="slate",
    font=[
        "Arial",
        "ui-sans-serif",
        "system-ui",
        "Segoe UI",
        "Roboto",
        "Helvetica Neue",
        "Noto Sans",
        "Liberation Sans",
        "sans-serif",
    ],
)

CSS = """
.gradio-container{
  max-width: 1040px !important;
  margin: 0 auto !important;
  font-family: Arial, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Noto Sans, Liberation Sans, sans-serif !important;
}

#header {
  padding: 14px 16px;
  border-radius: 16px;
  border: 1px solid rgba(0,0,0,0.08);
  background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
  text-align: center;
}
#header h1{
  margin: 0 0 6px 0;
  font-size: 24px;
  font-weight: 800;
  letter-spacing: -0.2px;
}
#header p{
  margin: 6px auto 0 auto;
  max-width: 720px;
  color: var(--body-text-color-subdued);
  font-size: 14px;
  line-height: 1.6;
}
#header hr{
  margin-top: 18px;
  border: none;
  height: 1px;
  background: linear-gradient(to right, transparent, #ddd, transparent);
}

.spec img { border-radius: 14px; }
.audio { border-radius: 14px !important; overflow: hidden; }

#run_btn{
  border-radius: 12px !important;
  font-weight: 800 !important;
}

#status_md p{ margin: 0.35rem 0; }
"""

with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
    gr.Markdown(
        "# DPDFNet Speech Enhancement\n\n"
        "Causal · Real-Time · Edge-Ready\n\n"
        "DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve "
        "long-range temporal and cross-band modeling while preserving low latency. "
        "Designed for single-channel streaming speech enhancement under challenging noise conditions.\n\n"
        "---",
        elem_id="header",
    )

    with gr.Row():
        model_key = gr.Dropdown(
            choices=list(MODEL_PRESETS.keys()),
            value=DEFAULT_MODEL_KEY,
            label="Model",
            # info="Audio is resampled to model SR, enhanced with ONNX, then resampled back.",
            interactive=True,
        )

        source = gr.Radio(
            choices=["Microphone", "Upload"],
            value="Upload",
            label="Input source",
        )

    with gr.Row():
        mic_audio = gr.Audio(
            sources=["microphone"],
            type="filepath",
            format="wav",
            label="Microphone (max 10s)",
            visible=False,
            buttons=["download"],
            elem_classes=["audio"],
        )
        file_audio = gr.Audio(
            sources=["upload"],
            type="filepath",
            format="wav",
            label="Upload file (WAV/MP3/FLAC etc., max 10s)",
            visible=True,
            buttons=["download"],
            elem_classes=["audio"],
        )

    run_btn = gr.Button("Enhance", variant="primary", elem_id="run_btn")
    status = gr.Markdown(elem_id="status_md")

    gr.Markdown("## Results")

    with gr.Row():
        out_noisy = gr.Audio(label="Before (mono)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
        out_enh = gr.Audio(label="After (enhanced)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])

    with gr.Row():
        img_noisy = gr.Image(label="Noisy spectrogram", elem_classes=["spec"])
        img_enh = gr.Image(label="Enhanced spectrogram", elem_classes=["spec"])

    source.change(fn=set_source_visibility, inputs=source, outputs=[mic_audio, file_audio])
    run_btn.click(
        fn=run_enhancement,
        inputs=[source, mic_audio, file_audio, model_key],
        outputs=[out_noisy, out_enh, img_noisy, img_enh, status],
        api_name="enhance",
    )

if __name__ == "__main__":
    demo.queue(max_size=32).launch()