danielr-ceva commited on
Commit
c4d3070
·
verified ·
1 Parent(s): cba8d01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -73
app.py CHANGED
@@ -52,6 +52,7 @@ def discover_model_presets() -> Dict[str, ModelSpec]:
52
  "dpdfnet4",
53
  "dpdfnet8",
54
  "dpdfnet2_48khz_hr",
 
55
  ]
56
 
57
  found_paths = {p.stem: p for p in ONNX_DIR.glob("*.onnx") if p.is_file()}
@@ -125,35 +126,34 @@ def get_ort_session(model_key: str) -> ort.InferenceSession:
125
  return sess
126
 
127
 
128
- def _resolve_state_path(model_key: str) -> Path:
129
- spec = MODEL_PRESETS[model_key]
130
- model_path = Path(spec.onnx_path)
131
- state_path = model_path.with_name(f"{model_path.stem}_state.npz")
132
- if not state_path.is_file():
133
- raise gr.Error(f"State file not found: {state_path}")
134
- return state_path
135
-
136
-
137
  def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
138
  if model_key in _INIT_STATES:
139
  return _INIT_STATES[model_key]
140
 
141
- state_path = _resolve_state_path(model_key)
142
- with np.load(state_path) as data:
143
- if "init_state" not in data:
144
- raise gr.Error(f"Missing 'init_state' key in state file: {state_path}")
145
- init_state = np.ascontiguousarray(data["init_state"].astype(np.float32, copy=False))
146
 
147
- expected_shape = session.get_inputs()[1].shape
148
- if len(expected_shape) != init_state.ndim:
 
 
 
 
 
 
 
 
 
 
149
  raise gr.Error(
150
- f"Initial state rank mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
 
151
  )
152
- for exp_dim, act_dim in zip(expected_shape, init_state.shape):
153
- if isinstance(exp_dim, int) and exp_dim != act_dim:
154
- raise gr.Error(
155
- f"Initial state shape mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
156
- )
157
 
158
  _INIT_STATES[model_key] = init_state
159
  return init_state
@@ -170,11 +170,7 @@ def vorbis_window(window_len: int) -> np.ndarray:
170
  return window.astype(np.float32)
171
 
172
 
173
- def get_wnorm(window_len: int, frame_size: int) -> float:
174
- return 1.0 / (window_len ** 2 / (2 * frame_size))
175
-
176
-
177
- def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, float, np.ndarray]:
178
  # ONNX spec input is [B, T, F, 2] (or dynamic variants).
179
  spec_shape = session.get_inputs()[0].shape
180
  freq_bins = spec_shape[-2] if len(spec_shape) >= 2 else None
@@ -188,11 +184,10 @@ def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[i
188
 
189
  hop = win_len // 2
190
  win = vorbis_window(win_len)
191
- wnorm = get_wnorm(win_len, hop)
192
- return win_len, hop, wnorm, win
193
 
194
 
195
- def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
196
  audio = np.asarray(waveform, dtype=np.float32).reshape(-1)
197
  audio_pad = np.pad(audio, (0, win_len), mode="constant")
198
 
@@ -205,12 +200,12 @@ def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, wnorm: fl
205
  center=True,
206
  pad_mode="reflect",
207
  )
208
- spec = (spec.T * wnorm).astype(np.complex64, copy=False) # [T, F]
209
  spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32, copy=False) # [T, F, 2]
210
- return spec_ri[None, ...] # [1, T, F, 2]
211
 
212
 
213
- def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
214
  spec_c = np.asarray(spec_e[0], dtype=np.float32) # [T, F, 2]
215
  spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64, copy=False) # [F, T]
216
 
@@ -223,12 +218,10 @@ def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, wnorm: float,
223
  length=None,
224
  ).astype(np.float32, copy=False)
225
 
226
- waveform_e = waveform_e / wnorm
227
- waveform_e = np.concatenate(
228
  [waveform_e[win_len * 2 :], np.zeros(win_len * 2, dtype=np.float32)],
229
  axis=0,
230
  )
231
- return waveform_e
232
 
233
 
234
  # -----------------------------
@@ -253,8 +246,8 @@ def enhance_audio_onnx(
253
  out_state_name = outputs[1].name
254
 
255
  waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
256
- win_len, hop, wnorm, win = _infer_stft_params(model_key, sess)
257
- spec_r_np = _preprocess_waveform(waveform, win_len=win_len, hop=hop, wnorm=wnorm, win=win)
258
 
259
  state = _load_initial_state(model_key, sess).copy()
260
  spec_e_frames = []
@@ -272,7 +265,7 @@ def enhance_audio_onnx(
272
  return waveform
273
 
274
  spec_e_np = np.concatenate(spec_e_frames, axis=1)
275
- waveform_e = _postprocess_spec(spec_e_np, win_len=win_len, hop=hop, wnorm=wnorm, win=win)
276
  return np.asarray(waveform_e, dtype=np.float32).reshape(-1)
277
 
278
 
@@ -428,18 +421,26 @@ CSS = """
428
  border-radius: 16px;
429
  border: 1px solid rgba(0,0,0,0.08);
430
  background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
 
431
  }
432
  #header h1{
433
- margin: 0;
434
  font-size: 24px;
435
  font-weight: 800;
436
  letter-spacing: -0.2px;
437
  }
438
  #header p{
439
- margin: 6px 0 0 0;
 
440
  color: var(--body-text-color-subdued);
441
- font-size: 13.5px;
442
- line-height: 1.35;
 
 
 
 
 
 
443
  }
444
 
445
  .spec img { border-radius: 14px; }
@@ -454,36 +455,14 @@ CSS = """
454
  """
455
 
456
  with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
457
- gr.HTML(
458
- # """
459
- # <div id="header">
460
- # <h1>DPDFNet Speech Enhancement</h1>
461
- # <p>
462
- # Upload or record up to 10 seconds. Multi-channel inputs are averaged to mono.
463
- # Choose any local ONNX model from <code>./onnx</code>.
464
- # Pre/postprocessing uses the same non-streaming STFT/iSTFT flow as <code>streaming/infer_dpdfnet_onnx.py</code>.
465
- # </p>
466
- # </div>
467
- # """
468
- """
469
- <div id="header" style="text-align: center; margin-bottom: 25px;">
470
-
471
- <h1 style="margin-bottom: 6px;">DPDFNet Speech Enhancement</h1>
472
-
473
- <p style="font-size: 14px; letter-spacing: 1px; margin-bottom: 14px; color: #555;">
474
- Causal • Real-Time • Edge-Ready
475
- </p>
476
-
477
- <p style="max-width: 720px; margin: 0 auto; font-size: 15px; line-height: 1.6;">
478
- DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve
479
- long-range temporal and cross-band modeling while preserving low latency.
480
- Designed for single-channel streaming speech enhancement under challenging noise conditions.
481
- </p>
482
-
483
- <hr style="margin-top: 22px; border: none; height: 1px; background: linear-gradient(to right, transparent, #ddd, transparent);">
484
-
485
- </div>
486
- """
487
  )
488
 
489
  with gr.Row():
 
52
  "dpdfnet4",
53
  "dpdfnet8",
54
  "dpdfnet2_48khz_hr",
55
+ "dpdfnet8_48khz_hr",
56
  ]
57
 
58
  found_paths = {p.stem: p for p in ONNX_DIR.glob("*.onnx") if p.is_file()}
 
126
  return sess
127
 
128
 
 
 
 
 
 
 
 
 
 
129
  def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
130
  if model_key in _INIT_STATES:
131
  return _INIT_STATES[model_key]
132
 
133
+ if len(session.get_inputs()) < 2:
134
+ raise gr.Error("Expected streaming ONNX model with two inputs: (spec, state).")
 
 
 
135
 
136
+ meta = session.get_modelmeta().custom_metadata_map
137
+ try:
138
+ state_size = int(meta["state_size"])
139
+ erb_norm_state_size = int(meta["erb_norm_state_size"])
140
+ spec_norm_state_size = int(meta["spec_norm_state_size"])
141
+ erb_norm_init = np.array(
142
+ [float(x) for x in meta["erb_norm_init"].split(",")], dtype=np.float32
143
+ )
144
+ spec_norm_init = np.array(
145
+ [float(x) for x in meta["spec_norm_init"].split(",")], dtype=np.float32
146
+ )
147
+ except KeyError as exc:
148
  raise gr.Error(
149
+ f"ONNX model is missing required metadata key: {exc}. "
150
+ "Re-export the model to embed state initialisation metadata."
151
  )
152
+
153
+ init_state = np.zeros(state_size, dtype=np.float32)
154
+ init_state[0:erb_norm_state_size] = erb_norm_init
155
+ init_state[erb_norm_state_size:erb_norm_state_size + spec_norm_state_size] = spec_norm_init
156
+ init_state = np.ascontiguousarray(init_state)
157
 
158
  _INIT_STATES[model_key] = init_state
159
  return init_state
 
170
  return window.astype(np.float32)
171
 
172
 
173
+ def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, np.ndarray]:
 
 
 
 
174
  # ONNX spec input is [B, T, F, 2] (or dynamic variants).
175
  spec_shape = session.get_inputs()[0].shape
176
  freq_bins = spec_shape[-2] if len(spec_shape) >= 2 else None
 
184
 
185
  hop = win_len // 2
186
  win = vorbis_window(win_len)
187
+ return win_len, hop, win
 
188
 
189
 
190
+ def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
191
  audio = np.asarray(waveform, dtype=np.float32).reshape(-1)
192
  audio_pad = np.pad(audio, (0, win_len), mode="constant")
193
 
 
200
  center=True,
201
  pad_mode="reflect",
202
  )
203
+ spec = spec.T.astype(np.complex64, copy=False) # [T, F]
204
  spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32, copy=False) # [T, F, 2]
205
+ return np.ascontiguousarray(spec_ri[None, ...], dtype=np.float32) # [1, T, F, 2]
206
 
207
 
208
+ def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
209
  spec_c = np.asarray(spec_e[0], dtype=np.float32) # [T, F, 2]
210
  spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64, copy=False) # [F, T]
211
 
 
218
  length=None,
219
  ).astype(np.float32, copy=False)
220
 
221
+ return np.concatenate(
 
222
  [waveform_e[win_len * 2 :], np.zeros(win_len * 2, dtype=np.float32)],
223
  axis=0,
224
  )
 
225
 
226
 
227
  # -----------------------------
 
246
  out_state_name = outputs[1].name
247
 
248
  waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
249
+ win_len, hop, win = _infer_stft_params(model_key, sess)
250
+ spec_r_np = _preprocess_waveform(waveform, win_len=win_len, hop=hop, win=win)
251
 
252
  state = _load_initial_state(model_key, sess).copy()
253
  spec_e_frames = []
 
265
  return waveform
266
 
267
  spec_e_np = np.concatenate(spec_e_frames, axis=1)
268
+ waveform_e = _postprocess_spec(spec_e_np, win_len=win_len, hop=hop, win=win)
269
  return np.asarray(waveform_e, dtype=np.float32).reshape(-1)
270
 
271
 
 
421
  border-radius: 16px;
422
  border: 1px solid rgba(0,0,0,0.08);
423
  background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
424
+ text-align: center;
425
  }
426
  #header h1{
427
+ margin: 0 0 6px 0;
428
  font-size: 24px;
429
  font-weight: 800;
430
  letter-spacing: -0.2px;
431
  }
432
  #header p{
433
+ margin: 6px auto 0 auto;
434
+ max-width: 720px;
435
  color: var(--body-text-color-subdued);
436
+ font-size: 14px;
437
+ line-height: 1.6;
438
+ }
439
+ #header hr{
440
+ margin-top: 18px;
441
+ border: none;
442
+ height: 1px;
443
+ background: linear-gradient(to right, transparent, #ddd, transparent);
444
  }
445
 
446
  .spec img { border-radius: 14px; }
 
455
  """
456
 
457
  with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
458
+ gr.Markdown(
459
+ "# DPDFNet Speech Enhancement\n\n"
460
+ "Causal · Real-Time · Edge-Ready\n\n"
461
+ "DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve "
462
+ "long-range temporal and cross-band modeling while preserving low latency. "
463
+ "Designed for single-channel streaming speech enhancement under challenging noise conditions.\n\n"
464
+ "---",
465
+ elem_id="header",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  )
467
 
468
  with gr.Row():