|
|
| """
|
| 高级去混响模块 - 基于二进制残差掩码和时域一致性
|
| 参考: arXiv 2510.00356 - Dereverberation Using Binary Residual Masking
|
| """
|
| import numpy as np
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Tuple, Optional
|
|
|
|
|
class BinaryResidualMask(nn.Module):
    """Small U-Net that predicts a per-bin suppression mask for dereverberation.

    Design intent:
      1. learn to identify and suppress late reflections;
      2. keep the direct-path component;
      3. learn phase implicitly through a time-domain consistency loss.
    Only the mask network lives in this class; no loss is defined here.

    The output layer is a sigmoid, so the mask is a soft value in [0, 1]; it is
    not thresholded to {0, 1} inside this module.

    Spatial sizes need not be multiples of 8 (e.g. the 1025 frequency bins of
    ``n_fft=2048``): up-sampled decoder features are zero-padded to match
    their skip connection. Each spatial dimension must still be at least 8 so
    that the three 2x2 poolings leave a non-empty bottleneck.

    Args:
        n_fft: FFT size the network is meant for; only ``freq_bins`` is
            derived from it, the layers themselves are size-agnostic.
        hop_length: STFT hop length; stored for reference, unused by layers.
    """

    def __init__(self, n_fft=2048, hop_length=512):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.freq_bins = n_fft // 2 + 1

        # Encoder: full, 1/2 and 1/4 resolution feature maps (32/64/128 ch).
        self.encoder1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        self.encoder2 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.encoder3 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        # Bottleneck at 1/8 resolution.
        self.bottleneck = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        # Decoders: element [0] is the 2x up-sampling layer; elements [1:]
        # are the fusion conv applied AFTER concatenating the skip
        # connection, hence input channels = 2 x output channels.
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        # 1x1 projection to a single-channel mask in [0, 1].
        self.output = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` at the bottom/right so its last two dims equal ``ref``'s."""
        # MaxPool2d floors odd sizes and ConvTranspose2d(k=2, s=2) doubles,
        # so the up-sampled map is at most one row/column short.
        pad_f = ref.shape[-2] - x.shape[-2]
        pad_t = ref.shape[-1] - x.shape[-1]
        if pad_f or pad_t:
            x = F.pad(x, (0, pad_t, 0, pad_f))
        return x

    def _decode(self, decoder, x, skip):
        """Up-sample ``x`` with ``decoder[0]``, concat ``skip``, fuse with ``decoder[1:]``."""
        up = self._match_size(decoder[0](x), skip)
        return decoder[1:](torch.cat([up, skip], dim=1))

    def forward(self, x):
        """
        Args:
            x: [B, 1, F, T] input magnitude spectrogram.

        Returns:
            mask: [B, 1, F, T] soft residual mask in [0, 1].
        """
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)

        b = self.bottleneck(e3)

        d3 = self._decode(self.decoder3, b, e3)
        d2 = self._decode(self.decoder2, d3, e2)
        d1 = self._decode(self.decoder1, d2, e1)

        return self.output(d1)
|
|
|
|
|
def advanced_dereverb(
    audio: np.ndarray,
    sr: int = 16000,
    n_fft: int = 2048,
    hop_length: int = 512,
    device: str = "cuda",
) -> Tuple[np.ndarray, np.ndarray]:
    """Split a mono signal into a dry (direct-path) part and a reverb tail.

    This is a hand-tuned spectral heuristic; it does not use
    ``BinaryResidualMask``. Steps:

    1. STFT -> magnitude and phase.
    2. Per-bin late-reflection estimate: each frame is the max of the previous
       estimate decayed by 0.92 and the magnitude two frames earlier x 0.80.
    3. Subtract 0.75 x that estimate from the magnitude (floored at 0).
    4. Raise the result to a loudness-dependent floor so louder frames keep
       more of their original magnitude.
    5. Smooth along time with a [1, 2, 3, 2, 1] / 9 kernel and clip to [0, mag].
    6. Reverb magnitude is the remainder ``mag - dry``. Both parts reuse the
       mixture phase, so ``dry + reverb`` re-synthesises the input up to
       STFT round-trip error.

    Args:
        audio: Mono input signal (1-D array of samples).
        sr: Sample rate. Not used by the algorithm; kept for API compatibility.
        n_fft: FFT size, also used as the window length.
        hop_length: STFT hop length.
        device: Not used; everything runs on CPU with NumPy.

    Returns:
        (dry_signal, reverb_tail): float32 arrays of ``len(audio)`` samples.
    """
    import librosa

    spec = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
    mag = np.abs(spec).astype(np.float32)
    phase = np.angle(spec)
    n_frames = mag.shape[1]

    # Frame loudness in dB, referenced to the clip's 90th percentile so the
    # thresholds below are relative to the clip's own level.
    rms = librosa.feature.rms(y=audio, frame_length=n_fft, hop_length=hop_length, center=True)[0]
    rms_db = 20.0 * np.log10(rms + 1e-8)
    ref_db = float(np.percentile(rms_db, 90))

    # Align the loudness track with the STFT frames (edge-pad or truncate) so
    # it always broadcasts against `mag`.
    if len(rms_db) < n_frames:
        rms_db = np.pad(rms_db, (0, n_frames - len(rms_db)), mode="edge")
    else:
        rms_db = rms_db[:n_frames]

    # Recursive late-reflection estimate; frames 0 and 1 stay at zero.
    late_reflections = np.zeros_like(mag, dtype=np.float32)
    for t in range(2, n_frames):
        late_reflections[:, t] = np.maximum(
            late_reflections[:, t - 1] * 0.92,
            mag[:, t - 2] * 0.80,
        )

    direct_path = np.maximum(mag - 0.75 * late_reflections, 0.0)

    # 0 for frames >= 35 dB below the reference, ramping to 1 at 10 dB below.
    vocal_strength = np.clip((rms_db - (ref_db - 35.0)) / 25.0, 0.0, 1.0)

    # Floor of 8%..20% of the mixture magnitude (rising with loudness),
    # scaled down where the bin is dominated by the reflection estimate.
    reverb_ratio = np.clip(late_reflections / (mag + 1e-8), 0.0, 1.0)
    floor_coef = 0.08 + 0.12 * vocal_strength[np.newaxis, :]
    floor = (1.0 - reverb_ratio) * floor_coef * mag
    direct_path = np.maximum(direct_path, floor)

    # Temporal smoothing with a symmetric 5-tap kernel and zero padding at the
    # edges. Shifted slices keep exactly n_frames columns even for clips
    # shorter than the kernel (np.convolve(mode="same") would not).
    kernel = np.array([1, 2, 3, 2, 1], dtype=np.float32)
    kernel /= np.sum(kernel)
    half = len(kernel) // 2
    padded = np.pad(direct_path, ((0, 0), (half, half)))
    direct_path = sum(w * padded[:, k:k + n_frames] for k, w in enumerate(kernel))
    direct_path = np.clip(direct_path, 0.0, mag)

    reverb_mag = np.maximum(mag - direct_path, 0.0)

    # Re-synthesise both components with the mixture phase.
    phasor = np.exp(1j * phase)
    dry_signal = librosa.istft(
        direct_path * phasor, hop_length=hop_length, win_length=n_fft, length=len(audio)
    )
    reverb_tail = librosa.istft(
        reverb_mag * phasor, hop_length=hop_length, win_length=n_fft, length=len(audio)
    )

    return dry_signal.astype(np.float32), reverb_tail.astype(np.float32)
|
|
|
|
|
def apply_reverb_to_converted(
    converted_dry: np.ndarray,
    original_reverb: np.ndarray,
    mix_ratio: float = 0.8
) -> np.ndarray:
    """Mix a reverb tail back onto a converted dry signal.

    Both inputs are cut to the shorter of their two lengths, the reverb is
    scaled by ``mix_ratio`` and added, and the sum is passed through
    ``lib.audio.soft_clip`` (threshold 0.9, ceiling 0.99) to limit peaks.

    Args:
        converted_dry: Dry signal after conversion.
        original_reverb: Reverb tail to mix back in.
        mix_ratio: Gain applied to the reverb before summing (intended 0-1).

    Returns:
        float32 array of ``min(len(converted_dry), len(original_reverb))``
        samples.
    """
    from lib.audio import soft_clip

    n_common = min(len(converted_dry), len(original_reverb))
    mixed = converted_dry[:n_common] + original_reverb[:n_common] * mix_ratio

    limited = soft_clip(mixed, threshold=0.9, ceiling=0.99)
    return limited.astype(np.float32)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: dereverberate a synthetic 440 Hz tone with three echoes.
    print("Testing advanced dereverberation...")

    sr = 16000
    duration = 2.0
    n_samples = int(sr * duration)
    # Sample instants k / sr. np.linspace(0, duration, n) would include the
    # endpoint and stretch the spacing to duration / (n - 1).
    t = np.arange(n_samples) / sr

    dry = np.sin(2 * np.pi * 440 * t).astype(np.float32)

    # Echoes at 50 / 100 / 150 ms with gains 0.5 / 0.25 / 0.125.
    reverb = np.zeros_like(dry)
    delay_samples = int(0.05 * sr)
    for i in range(3):
        delay = delay_samples * (i + 1)
        decay = 0.5 ** (i + 1)
        if 0 < delay < len(reverb):
            reverb[delay:] += dry[:-delay] * decay

    wet = dry + reverb * 0.5

    dry_extracted, reverb_extracted = advanced_dereverb(wet, sr)

    input_rms = float(np.sqrt(np.mean(wet ** 2)))
    dry_rms = float(np.sqrt(np.mean(dry_extracted ** 2)))
    reverb_rms = float(np.sqrt(np.mean(reverb_extracted ** 2)))

    print(f"Input RMS: {input_rms:.4f}")
    print(f"Dry RMS: {dry_rms:.4f}")
    print(f"Reverb RMS: {reverb_rms:.4f}")
    print(f"Separation ratio: {dry_rms / (reverb_rms + 1e-8):.2f}")

    # Verify the outputs before claiming success.
    assert dry_extracted.shape == wet.shape, dry_extracted.shape
    assert reverb_extracted.shape == wet.shape, reverb_extracted.shape
    assert np.all(np.isfinite(dry_extracted)) and np.all(np.isfinite(reverb_extracted))
    # Dry and reverb magnitudes sum to the mixture magnitude and share its
    # phase, so the two signals must add back up to the input.
    assert np.allclose(dry_extracted + reverb_extracted, wet, atol=1e-3)

    print("\n[OK] Advanced dereverberation test passed!")
|
|
|