Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -52,6 +52,7 @@ def discover_model_presets() -> Dict[str, ModelSpec]:
|
|
| 52 |
"dpdfnet4",
|
| 53 |
"dpdfnet8",
|
| 54 |
"dpdfnet2_48khz_hr",
|
|
|
|
| 55 |
]
|
| 56 |
|
| 57 |
found_paths = {p.stem: p for p in ONNX_DIR.glob("*.onnx") if p.is_file()}
|
|
@@ -125,35 +126,34 @@ def get_ort_session(model_key: str) -> ort.InferenceSession:
|
|
| 125 |
return sess
|
| 126 |
|
| 127 |
|
| 128 |
-
def _resolve_state_path(model_key: str) -> Path:
|
| 129 |
-
spec = MODEL_PRESETS[model_key]
|
| 130 |
-
model_path = Path(spec.onnx_path)
|
| 131 |
-
state_path = model_path.with_name(f"{model_path.stem}_state.npz")
|
| 132 |
-
if not state_path.is_file():
|
| 133 |
-
raise gr.Error(f"State file not found: {state_path}")
|
| 134 |
-
return state_path
|
| 135 |
-
|
| 136 |
-
|
| 137 |
def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
|
| 138 |
if model_key in _INIT_STATES:
|
| 139 |
return _INIT_STATES[model_key]
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
if "init_state" not in data:
|
| 144 |
-
raise gr.Error(f"Missing 'init_state' key in state file: {state_path}")
|
| 145 |
-
init_state = np.ascontiguousarray(data["init_state"].astype(np.float32, copy=False))
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
raise gr.Error(
|
| 150 |
-
f"
|
|
|
|
| 151 |
)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
|
| 158 |
_INIT_STATES[model_key] = init_state
|
| 159 |
return init_state
|
|
@@ -170,11 +170,7 @@ def vorbis_window(window_len: int) -> np.ndarray:
|
|
| 170 |
return window.astype(np.float32)
|
| 171 |
|
| 172 |
|
| 173 |
-
def
|
| 174 |
-
return 1.0 / (window_len ** 2 / (2 * frame_size))
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, float, np.ndarray]:
|
| 178 |
# ONNX spec input is [B, T, F, 2] (or dynamic variants).
|
| 179 |
spec_shape = session.get_inputs()[0].shape
|
| 180 |
freq_bins = spec_shape[-2] if len(spec_shape) >= 2 else None
|
|
@@ -188,11 +184,10 @@ def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[i
|
|
| 188 |
|
| 189 |
hop = win_len // 2
|
| 190 |
win = vorbis_window(win_len)
|
| 191 |
-
|
| 192 |
-
return win_len, hop, wnorm, win
|
| 193 |
|
| 194 |
|
| 195 |
-
def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int,
|
| 196 |
audio = np.asarray(waveform, dtype=np.float32).reshape(-1)
|
| 197 |
audio_pad = np.pad(audio, (0, win_len), mode="constant")
|
| 198 |
|
|
@@ -205,12 +200,12 @@ def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, wnorm: fl
|
|
| 205 |
center=True,
|
| 206 |
pad_mode="reflect",
|
| 207 |
)
|
| 208 |
-
spec =
|
| 209 |
spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32, copy=False) # [T, F, 2]
|
| 210 |
-
return spec_ri[None, ...] # [1, T, F, 2]
|
| 211 |
|
| 212 |
|
| 213 |
-
def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int,
|
| 214 |
spec_c = np.asarray(spec_e[0], dtype=np.float32) # [T, F, 2]
|
| 215 |
spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64, copy=False) # [F, T]
|
| 216 |
|
|
@@ -223,12 +218,10 @@ def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, wnorm: float,
|
|
| 223 |
length=None,
|
| 224 |
).astype(np.float32, copy=False)
|
| 225 |
|
| 226 |
-
|
| 227 |
-
waveform_e = np.concatenate(
|
| 228 |
[waveform_e[win_len * 2 :], np.zeros(win_len * 2, dtype=np.float32)],
|
| 229 |
axis=0,
|
| 230 |
)
|
| 231 |
-
return waveform_e
|
| 232 |
|
| 233 |
|
| 234 |
# -----------------------------
|
|
@@ -253,8 +246,8 @@ def enhance_audio_onnx(
|
|
| 253 |
out_state_name = outputs[1].name
|
| 254 |
|
| 255 |
waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
|
| 256 |
-
win_len, hop,
|
| 257 |
-
spec_r_np = _preprocess_waveform(waveform, win_len=win_len, hop=hop,
|
| 258 |
|
| 259 |
state = _load_initial_state(model_key, sess).copy()
|
| 260 |
spec_e_frames = []
|
|
@@ -272,7 +265,7 @@ def enhance_audio_onnx(
|
|
| 272 |
return waveform
|
| 273 |
|
| 274 |
spec_e_np = np.concatenate(spec_e_frames, axis=1)
|
| 275 |
-
waveform_e = _postprocess_spec(spec_e_np, win_len=win_len, hop=hop,
|
| 276 |
return np.asarray(waveform_e, dtype=np.float32).reshape(-1)
|
| 277 |
|
| 278 |
|
|
@@ -428,18 +421,26 @@ CSS = """
|
|
| 428 |
border-radius: 16px;
|
| 429 |
border: 1px solid rgba(0,0,0,0.08);
|
| 430 |
background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
|
|
|
|
| 431 |
}
|
| 432 |
#header h1{
|
| 433 |
-
margin: 0;
|
| 434 |
font-size: 24px;
|
| 435 |
font-weight: 800;
|
| 436 |
letter-spacing: -0.2px;
|
| 437 |
}
|
| 438 |
#header p{
|
| 439 |
-
margin: 6px
|
|
|
|
| 440 |
color: var(--body-text-color-subdued);
|
| 441 |
-
font-size:
|
| 442 |
-
line-height: 1.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
}
|
| 444 |
|
| 445 |
.spec img { border-radius: 14px; }
|
|
@@ -454,36 +455,14 @@ CSS = """
|
|
| 454 |
"""
|
| 455 |
|
| 456 |
with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
|
| 457 |
-
gr.
|
| 458 |
-
# "
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
# </p>
|
| 466 |
-
# </div>
|
| 467 |
-
# """
|
| 468 |
-
"""
|
| 469 |
-
<div id="header" style="text-align: center; margin-bottom: 25px;">
|
| 470 |
-
|
| 471 |
-
<h1 style="margin-bottom: 6px;">DPDFNet Speech Enhancement</h1>
|
| 472 |
-
|
| 473 |
-
<p style="font-size: 14px; letter-spacing: 1px; margin-bottom: 14px; color: #555;">
|
| 474 |
-
Causal • Real-Time • Edge-Ready
|
| 475 |
-
</p>
|
| 476 |
-
|
| 477 |
-
<p style="max-width: 720px; margin: 0 auto; font-size: 15px; line-height: 1.6;">
|
| 478 |
-
DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve
|
| 479 |
-
long-range temporal and cross-band modeling while preserving low latency.
|
| 480 |
-
Designed for single-channel streaming speech enhancement under challenging noise conditions.
|
| 481 |
-
</p>
|
| 482 |
-
|
| 483 |
-
<hr style="margin-top: 22px; border: none; height: 1px; background: linear-gradient(to right, transparent, #ddd, transparent);">
|
| 484 |
-
|
| 485 |
-
</div>
|
| 486 |
-
"""
|
| 487 |
)
|
| 488 |
|
| 489 |
with gr.Row():
|
|
|
|
| 52 |
"dpdfnet4",
|
| 53 |
"dpdfnet8",
|
| 54 |
"dpdfnet2_48khz_hr",
|
| 55 |
+
"dpdfnet8_48khz_hr",
|
| 56 |
]
|
| 57 |
|
| 58 |
found_paths = {p.stem: p for p in ONNX_DIR.glob("*.onnx") if p.is_file()}
|
|
|
|
| 126 |
return sess
|
| 127 |
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
|
| 130 |
if model_key in _INIT_STATES:
|
| 131 |
return _INIT_STATES[model_key]
|
| 132 |
|
| 133 |
+
if len(session.get_inputs()) < 2:
|
| 134 |
+
raise gr.Error("Expected streaming ONNX model with two inputs: (spec, state).")
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
+
meta = session.get_modelmeta().custom_metadata_map
|
| 137 |
+
try:
|
| 138 |
+
state_size = int(meta["state_size"])
|
| 139 |
+
erb_norm_state_size = int(meta["erb_norm_state_size"])
|
| 140 |
+
spec_norm_state_size = int(meta["spec_norm_state_size"])
|
| 141 |
+
erb_norm_init = np.array(
|
| 142 |
+
[float(x) for x in meta["erb_norm_init"].split(",")], dtype=np.float32
|
| 143 |
+
)
|
| 144 |
+
spec_norm_init = np.array(
|
| 145 |
+
[float(x) for x in meta["spec_norm_init"].split(",")], dtype=np.float32
|
| 146 |
+
)
|
| 147 |
+
except KeyError as exc:
|
| 148 |
raise gr.Error(
|
| 149 |
+
f"ONNX model is missing required metadata key: {exc}. "
|
| 150 |
+
"Re-export the model to embed state initialisation metadata."
|
| 151 |
)
|
| 152 |
+
|
| 153 |
+
init_state = np.zeros(state_size, dtype=np.float32)
|
| 154 |
+
init_state[0:erb_norm_state_size] = erb_norm_init
|
| 155 |
+
init_state[erb_norm_state_size:erb_norm_state_size + spec_norm_state_size] = spec_norm_init
|
| 156 |
+
init_state = np.ascontiguousarray(init_state)
|
| 157 |
|
| 158 |
_INIT_STATES[model_key] = init_state
|
| 159 |
return init_state
|
|
|
|
| 170 |
return window.astype(np.float32)
|
| 171 |
|
| 172 |
|
| 173 |
+
def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, np.ndarray]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
# ONNX spec input is [B, T, F, 2] (or dynamic variants).
|
| 175 |
spec_shape = session.get_inputs()[0].shape
|
| 176 |
freq_bins = spec_shape[-2] if len(spec_shape) >= 2 else None
|
|
|
|
| 184 |
|
| 185 |
hop = win_len // 2
|
| 186 |
win = vorbis_window(win_len)
|
| 187 |
+
return win_len, hop, win
|
|
|
|
| 188 |
|
| 189 |
|
| 190 |
+
def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
|
| 191 |
audio = np.asarray(waveform, dtype=np.float32).reshape(-1)
|
| 192 |
audio_pad = np.pad(audio, (0, win_len), mode="constant")
|
| 193 |
|
|
|
|
| 200 |
center=True,
|
| 201 |
pad_mode="reflect",
|
| 202 |
)
|
| 203 |
+
spec = spec.T.astype(np.complex64, copy=False) # [T, F]
|
| 204 |
spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32, copy=False) # [T, F, 2]
|
| 205 |
+
return np.ascontiguousarray(spec_ri[None, ...], dtype=np.float32) # [1, T, F, 2]
|
| 206 |
|
| 207 |
|
| 208 |
+
def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
|
| 209 |
spec_c = np.asarray(spec_e[0], dtype=np.float32) # [T, F, 2]
|
| 210 |
spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64, copy=False) # [F, T]
|
| 211 |
|
|
|
|
| 218 |
length=None,
|
| 219 |
).astype(np.float32, copy=False)
|
| 220 |
|
| 221 |
+
return np.concatenate(
|
|
|
|
| 222 |
[waveform_e[win_len * 2 :], np.zeros(win_len * 2, dtype=np.float32)],
|
| 223 |
axis=0,
|
| 224 |
)
|
|
|
|
| 225 |
|
| 226 |
|
| 227 |
# -----------------------------
|
|
|
|
| 246 |
out_state_name = outputs[1].name
|
| 247 |
|
| 248 |
waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
|
| 249 |
+
win_len, hop, win = _infer_stft_params(model_key, sess)
|
| 250 |
+
spec_r_np = _preprocess_waveform(waveform, win_len=win_len, hop=hop, win=win)
|
| 251 |
|
| 252 |
state = _load_initial_state(model_key, sess).copy()
|
| 253 |
spec_e_frames = []
|
|
|
|
| 265 |
return waveform
|
| 266 |
|
| 267 |
spec_e_np = np.concatenate(spec_e_frames, axis=1)
|
| 268 |
+
waveform_e = _postprocess_spec(spec_e_np, win_len=win_len, hop=hop, win=win)
|
| 269 |
return np.asarray(waveform_e, dtype=np.float32).reshape(-1)
|
| 270 |
|
| 271 |
|
|
|
|
| 421 |
border-radius: 16px;
|
| 422 |
border: 1px solid rgba(0,0,0,0.08);
|
| 423 |
background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
|
| 424 |
+
text-align: center;
|
| 425 |
}
|
| 426 |
#header h1{
|
| 427 |
+
margin: 0 0 6px 0;
|
| 428 |
font-size: 24px;
|
| 429 |
font-weight: 800;
|
| 430 |
letter-spacing: -0.2px;
|
| 431 |
}
|
| 432 |
#header p{
|
| 433 |
+
margin: 6px auto 0 auto;
|
| 434 |
+
max-width: 720px;
|
| 435 |
color: var(--body-text-color-subdued);
|
| 436 |
+
font-size: 14px;
|
| 437 |
+
line-height: 1.6;
|
| 438 |
+
}
|
| 439 |
+
#header hr{
|
| 440 |
+
margin-top: 18px;
|
| 441 |
+
border: none;
|
| 442 |
+
height: 1px;
|
| 443 |
+
background: linear-gradient(to right, transparent, #ddd, transparent);
|
| 444 |
}
|
| 445 |
|
| 446 |
.spec img { border-radius: 14px; }
|
|
|
|
| 455 |
"""
|
| 456 |
|
| 457 |
with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
|
| 458 |
+
gr.Markdown(
|
| 459 |
+
"# DPDFNet Speech Enhancement\n\n"
|
| 460 |
+
"Causal · Real-Time · Edge-Ready\n\n"
|
| 461 |
+
"DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve "
|
| 462 |
+
"long-range temporal and cross-band modeling while preserving low latency. "
|
| 463 |
+
"Designed for single-channel streaming speech enhancement under challenging noise conditions.\n\n"
|
| 464 |
+
"---",
|
| 465 |
+
elem_id="header",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
)
|
| 467 |
|
| 468 |
with gr.Row():
|