mason369 commited on
Commit
b6f9c90
·
verified ·
1 Parent(s): b15e31b

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. lib/__init__.py +8 -0
  2. lib/audio.py +139 -0
  3. lib/device.py +186 -0
  4. lib/logger.py +254 -0
  5. lib/mixer.py +214 -0
  6. lib/vocal_cleanup.py +253 -0
  7. lib/vocoder_fix.py +385 -0
lib/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 核心库模块
4
+ """
5
+ from .audio import load_audio, save_audio
6
+ from .device import get_device, get_device_info, empty_device_cache, supports_fp16
7
+
8
+ __all__ = ["load_audio", "save_audio", "get_device", "get_device_info", "empty_device_cache", "supports_fp16"]
lib/audio.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 音频处理模块 - 加载、保存和处理音频文件
4
+ """
5
+ import numpy as np
6
+ import librosa
7
+ import soundfile as sf
8
+ from typing import Tuple, Optional
9
+
10
+
11
def load_audio(path: str, sr: int = 16000) -> np.ndarray:
    """
    Load an audio file as mono float32, resampling to `sr` when needed.

    Args:
        path: Audio file path.
        sr: Target sample rate (default 16000).

    Returns:
        np.ndarray: Audio samples (float32, mono).
    """
    # Load at the native rate first so we resample at most once.
    data, native_sr = librosa.load(path, sr=None, mono=True)

    if native_sr == sr:
        return data.astype(np.float32)

    resampled = librosa.resample(data, orig_sr=native_sr, target_sr=sr)
    return resampled.astype(np.float32)
28
+
29
+
30
def save_audio(path: str, audio: np.ndarray, sr: int = 48000):
    """
    Write audio samples to a file.

    Args:
        path: Output file path.
        audio: Audio samples.
        sr: Sample rate (default 48000).
    """
    # Hard-clip into the legal [-1, 1] range before handing off to soundfile.
    clipped = np.clip(audio, -1.0, 1.0)
    sf.write(path, clipped, sr)
42
+
43
+
44
def soft_clip(
    audio: np.ndarray,
    threshold: float = 0.9,
    ceiling: float = 0.99,
) -> np.ndarray:
    """
    Tame peaks with a smooth tanh knee, preserving the body of the signal.

    Samples whose magnitude stays at or below `threshold` pass through
    untouched; anything above is compressed into (threshold, ceiling).

    Args:
        audio: Input audio samples.
        threshold: Level at which compression starts (must be > 0).
        ceiling: Soft-clip upper bound (must be > threshold).

    Returns:
        np.ndarray: Processed float32 audio.

    Raises:
        ValueError: If threshold <= 0 or ceiling <= threshold.
    """
    samples = np.asarray(audio, dtype=np.float32)

    if threshold <= 0:
        raise ValueError("threshold 必须大于 0")
    if ceiling <= threshold:
        raise ValueError("ceiling 必须大于 threshold")

    out = samples.copy()
    magnitude = np.abs(out)
    over = magnitude > threshold
    # Fast path: nothing exceeds the threshold.
    if not over.any():
        return out

    knee = ceiling - threshold
    # Map the overshoot through tanh so the output approaches, but never
    # reaches, the ceiling (epsilon guards against division by zero).
    excess = (magnitude[over] - threshold) / (knee + 1e-8)
    out[over] = np.sign(out[over]) * (threshold + knee * np.tanh(excess))
    return out.astype(np.float32, copy=False)
77
+
78
+
79
def soft_clip_array(
    audio: np.ndarray,
    threshold: float = 0.9,
    ceiling: float = 0.99,
) -> np.ndarray:
    """Array-oriented soft clip; works for mono or multi-channel input.

    Thin alias for `soft_clip`, which already operates element-wise and so
    handles any array shape.
    """
    return soft_clip(audio, threshold=threshold, ceiling=ceiling)
86
+
87
+
88
def get_audio_info(path: str) -> dict:
    """
    Read basic metadata for an audio file.

    Args:
        path: Audio file path.

    Returns:
        dict: duration (seconds), sample_rate, channels, and container format.
    """
    meta = sf.info(path)
    return {
        "duration": meta.duration,
        "sample_rate": meta.samplerate,
        "channels": meta.channels,
        "format": meta.format,
    }
105
+
106
+
107
def normalize_audio(audio: np.ndarray, target_db: float = -20.0) -> np.ndarray:
    """
    Normalize loudness so the RMS level matches `target_db`.

    Args:
        audio: Input audio samples.
        target_db: Target RMS loudness in dB.

    Returns:
        np.ndarray: Normalized audio, hard-clipped to [-1, 1].
    """
    current_rms = float(np.sqrt(np.mean(audio ** 2)))
    # Silence has no level to normalize; leave it as-is (just clip).
    if current_rms > 0:
        desired_rms = 10 ** (target_db / 20)
        audio = audio * (desired_rms / current_rms)
    return np.clip(audio, -1.0, 1.0)
123
+
124
+
125
def trim_silence(audio: np.ndarray, sr: int = 16000,
                 top_db: int = 30) -> np.ndarray:
    """
    Strip leading and trailing silence from audio.

    Args:
        audio: Input audio samples.
        sr: Sample rate.  NOTE(review): not actually passed to
            librosa.effects.trim below — kept for API symmetry with the
            other functions in this module.
        top_db: Threshold (dB below peak) under which audio counts as silence.

    Returns:
        np.ndarray: Audio with silent head and tail removed.
    """
    trimmed, _ = librosa.effects.trim(audio, top_db=top_db)
    return trimmed
lib/device.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 设备检测模块 - 自动检测并选择最佳计算设备
4
+ 支持: CUDA (NVIDIA / AMD ROCm), XPU (Intel Arc via IPEX), DirectML, MPS (Apple), CPU
5
+ """
6
+ import torch
7
+
8
+
9
+ def _has_xpu() -> bool:
10
+ """检测 Intel XPU (需要 intel_extension_for_pytorch)"""
11
+ try:
12
+ import intel_extension_for_pytorch # noqa: F401
13
+ return hasattr(torch, "xpu") and torch.xpu.is_available()
14
+ except ImportError:
15
+ return False
16
+
17
+
18
+ def _has_directml() -> bool:
19
+ """检测 DirectML (AMD/Intel on Windows)"""
20
+ try:
21
+ import torch_directml # noqa: F401
22
+ return True
23
+ except ImportError:
24
+ return False
25
+
26
+
27
+ def _has_mps() -> bool:
28
+ """检测 Apple MPS"""
29
+ if not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available():
30
+ return False
31
+ try:
32
+ torch.zeros(1).to(torch.device("mps"))
33
+ return True
34
+ except Exception:
35
+ return False
36
+
37
+
38
+ def _is_rocm() -> bool:
39
+ """检测当前 PyTorch 是否为 ROCm 构建 (AMD GPU)"""
40
+ return hasattr(torch.version, "hip") and torch.version.hip is not None
41
+
42
+
43
def get_device(preferred: str = "cuda") -> torch.device:
    """
    Pick a compute device, falling back automatically by priority.

    Args:
        preferred: Preferred device ("cuda", "xpu", "directml", "mps", "cpu").

    Returns:
        torch.device: An available compute device.  When the preferred one is
        unavailable, falls back in the order CUDA > XPU > DirectML > MPS > CPU.
    """
    p = preferred.lower().strip()

    # Honor an exact, available match for the requested backend first.
    if p in ("cuda", "cuda:0") and torch.cuda.is_available():
        return torch.device("cuda")
    if p in ("xpu", "xpu:0") and _has_xpu():
        return torch.device("xpu")
    # DirectML surfaces as the "privateuseone" device type in PyTorch.
    if (p == "directml" or p.startswith("privateuseone")) and _has_directml():
        import torch_directml
        return torch_directml.device(torch_directml.default_device())
    if p == "mps" and _has_mps():
        return torch.device("mps")
    if p == "cpu":
        return torch.device("cpu")

    # Auto-detect: CUDA (incl. ROCm) > XPU > DirectML > MPS > CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if _has_xpu():
        return torch.device("xpu")
    if _has_directml():
        import torch_directml
        return torch_directml.device(torch_directml.default_device())
    if _has_mps():
        return torch.device("mps")
    return torch.device("cpu")
79
+
80
+
81
def supports_fp16(device: torch.device) -> bool:
    """Return True when `device` can run FP16 inference reliably."""
    # Accept either a torch.device or a plain device-type string.
    device_type = str(device.type) if hasattr(device, "type") else str(device)
    # CUDA (incl. ROCm) and Intel XPU handle half precision well;
    # DirectML / MPS / CPU are unstable with FP16, so keep it off there.
    return device_type in ("cuda", "xpu")
90
+
91
+
92
def empty_device_cache(device: "torch.device | None" = None):
    """Free cached accelerator memory (backend-agnostic).

    Args:
        device: Specific device whose cache to clear; when None, every
            available backend's cache is cleared.
    """
    if device is not None:
        # Normalize to the device-type string ("cuda", "xpu", "mps", ...).
        dtype = str(device.type) if hasattr(device, "type") else str(device)
    else:
        dtype = None

    if (dtype is None or dtype == "cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()
    if (dtype is None or dtype == "xpu") and _has_xpu():
        torch.xpu.empty_cache()
    if (dtype is None or dtype == "mps") and _has_mps():
        # Older torch builds lack torch.mps.empty_cache.
        if hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()
106
+
107
+
108
def get_device_info() -> dict:
    """Collect details about available compute backends and devices.

    Returns:
        dict with keys:
            backends: list of backend description strings.
            current_device: backend string that auto-selection would pick.
            devices: per-device dicts (index, backend, name, total_memory_gb).
    """
    info = {
        "backends": [],
        "current_device": "cpu",
        "devices": []
    }

    # CUDA (NVIDIA, or AMD via a ROCm build)
    if torch.cuda.is_available():
        backend = "ROCm (AMD)" if _is_rocm() else "CUDA (NVIDIA)"
        info["backends"].append(backend)
        info["current_device"] = "cuda"
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            info["devices"].append({
                "index": i,
                "backend": backend,
                "name": props.name,
                "total_memory_gb": round(props.total_memory / (1024**3), 2),
            })

    # Intel XPU
    if _has_xpu():
        info["backends"].append("XPU (Intel)")
        # Only promote XPU to "current" when no CUDA device was listed above.
        if not info["devices"]:
            info["current_device"] = "xpu"
        for i in range(torch.xpu.device_count()):
            props = torch.xpu.get_device_properties(i)
            info["devices"].append({
                "index": i,
                "backend": "XPU (Intel)",
                "name": props.name,
                "total_memory_gb": round(props.total_memory / (1024**3), 2),
            })

    # DirectML
    if _has_directml():
        import torch_directml
        info["backends"].append("DirectML")
        if not info["devices"]:
            info["current_device"] = "directml"
        info["devices"].append({
            "index": 0,
            "backend": "DirectML",
            "name": torch_directml.device_name(0),
            "total_memory_gb": None,  # DirectML does not expose memory size
        })

    # MPS (Apple Silicon) — no per-device enumeration available.
    if _has_mps():
        info["backends"].append("MPS (Apple)")
        if not info["devices"]:
            info["current_device"] = "mps"

    if not info["backends"]:
        info["backends"].append("CPU")

    return info
167
+
168
+
169
def print_device_info():
    """Print a summary of available backends and devices to the console."""
    info = get_device_info()
    bar = "=" * 50

    print(bar)
    print("设备信息")
    print(bar)
    print(f"可用后端: {', '.join(info['backends'])}")
    print(f"当前设备: {info['current_device']}")

    for entry in info["devices"]:
        mem = f"{entry['total_memory_gb']} GB" if entry.get("total_memory_gb") else "N/A"
        print(f" [{entry['index']}] {entry['name']} ({entry['backend']}) - 显存: {mem}")

    if not info["devices"]:
        print(" 无 GPU 设备,将使用 CPU 进行推理")

    print(bar)
lib/logger.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 日志工具模块 - 支持时间戳和颜色输出
4
+ """
5
+ import sys
6
+ import logging
7
+ from datetime import datetime
8
+
9
# Optional color support: colorama provides cross-platform ANSI handling
# (notably on Windows); fall back to no-op placeholders when missing.
try:
    from colorama import init, Fore, Style, Back
    init(autoreset=True)  # Initialize colorama (Windows compatibility); autoreset resets style after each line
    COLORAMA_AVAILABLE = True
except ImportError:
    COLORAMA_AVAILABLE = False
    # Empty-string placeholders so the rest of the module can reference
    # the same attribute names without color codes.
    class Fore:
        LIGHTBLACK_EX = GREEN = YELLOW = RED = CYAN = BLUE = MAGENTA = WHITE = LIGHTGREEN_EX = LIGHTCYAN_EX = LIGHTYELLOW_EX = LIGHTMAGENTA_EX = ""
    class Style:
        RESET_ALL = BRIGHT = DIM = ""
    class Back:
        pass
22
+
23
+
24
class Logger:
    """Unified logging helper with timestamps and colored console output.

    All methods are static; the only state is the class-level `verbose` flag,
    which gates debug/detail/config output.
    """

    # ASCII fallbacks for symbols that legacy (e.g. GBK) consoles cannot encode.
    SAFE_CHAR_MAP = {
        "✓": "[OK] ",
        "✗": "[X] ",
        "→": "->",
        "◆": "*",
    }

    # Per-level ANSI color codes (empty strings when colorama is unavailable).
    COLORS = {
        "DEBUG": Fore.LIGHTBLACK_EX,
        "INFO": Fore.GREEN,
        "SUCCESS": Fore.LIGHTGREEN_EX,
        "WARNING": Fore.YELLOW,
        "ERROR": Fore.RED,
        "STEP": Fore.CYAN,
        "DETAIL": Fore.LIGHTCYAN_EX,
        "PROGRESS": Fore.MAGENTA,
        "MODEL": Fore.LIGHTMAGENTA_EX,
        "AUDIO": Fore.BLUE,
        "CONFIG": Fore.LIGHTYELLOW_EX,
    }

    RESET = Style.RESET_ALL
    BRIGHT = Style.BRIGHT
    DIM = Style.DIM

    # Verbose-logging switch
    verbose = True

    @staticmethod
    def _sanitize_console_text(text: str) -> str:
        """Replace characters incompatible with the current terminal encoding."""
        sanitized = text
        for src, dst in Logger.SAFE_CHAR_MAP.items():
            sanitized = sanitized.replace(src, dst)
        return sanitized

    @staticmethod
    def _emit(text: str):
        """Print safely, avoiding Unicode crashes on Windows/GBK consoles."""
        try:
            print(text, flush=True)
            return
        except UnicodeEncodeError:
            pass

        # First print failed: swap known-problematic symbols, then round-trip
        # through the console encoding with replacement characters.
        fallback = Logger._sanitize_console_text(text)
        encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
        try:
            print(
                fallback.encode(encoding, errors="replace").decode(encoding),
                flush=True,
            )
        except Exception:
            # Last resort: force pure ASCII output.
            print(
                fallback.encode("ascii", errors="replace").decode("ascii"),
                flush=True,
            )

    @staticmethod
    def _log(level: str, msg: str, force_print: bool = True):
        """Internal formatter: timestamp + level-specific prefix + color."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        color = Logger.COLORS.get(level, "")
        reset = Logger.RESET

        # Pick the prefix by level.
        if level in ("INFO", "STEP", "SUCCESS"):
            prefix = ""
        elif level == "DETAIL":
            prefix = " → "
        elif level == "PROGRESS":
            prefix = " ◆ "
        elif level == "MODEL":
            prefix = "[模型] "
        elif level == "AUDIO":
            prefix = "[音频] "
        elif level == "CONFIG":
            prefix = "[配置] "
        else:
            prefix = f"[{level}] "

        output = f"{color}[{timestamp}]{prefix}{msg}{reset}"
        Logger._emit(output)

    @staticmethod
    def debug(msg: str):
        """Debug log (grey) — only shown in verbose mode."""
        if Logger.verbose:
            Logger._log("DEBUG", msg)

    @staticmethod
    def info(msg: str):
        """Info log (green)."""
        Logger._log("INFO", msg)

    @staticmethod
    def success(msg: str):
        """Success log (bright green), prefixed with a check mark."""
        Logger._log("SUCCESS", f"✓ {msg}")

    @staticmethod
    def warning(msg: str):
        """Warning log (yellow)."""
        Logger._log("WARNING", msg)

    @staticmethod
    def error(msg: str):
        """Error log (red)."""
        Logger._log("ERROR", msg)

    @staticmethod
    def step(current: int, total: int, msg: str):
        """Step log (cyan), rendered as [current/total]."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        color = Logger.COLORS.get("STEP", "")
        reset = Logger.RESET
        Logger._emit(f"{color}[{timestamp}][{current}/{total}] {msg}{reset}")

    @staticmethod
    def detail(msg: str):
        """Detail log (light cyan) for processing details — verbose mode only."""
        if Logger.verbose:
            Logger._log("DETAIL", msg)

    @staticmethod
    def progress(msg: str):
        """Progress log (magenta) for processing progress."""
        Logger._log("PROGRESS", msg)

    @staticmethod
    def model(msg: str):
        """Model log (light magenta) for model load/unload messages."""
        Logger._log("MODEL", msg)

    @staticmethod
    def audio(msg: str):
        """Audio log (blue) for audio-processing messages."""
        Logger._log("AUDIO", msg)

    @staticmethod
    def config(msg: str):
        """Config log (light yellow) — verbose mode only."""
        if Logger.verbose:
            Logger._log("CONFIG", msg)

    @staticmethod
    def header(msg: str):
        """Header log with separator rules above and below."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        color = Logger.COLORS.get("INFO", "")
        reset = Logger.RESET
        Logger._emit(f"{color}[{timestamp}] {'=' * 50}{reset}")
        Logger._emit(f"{color}[{timestamp}] {msg}{reset}")
        Logger._emit(f"{color}[{timestamp}] {'=' * 50}{reset}")

    @staticmethod
    def separator(char: str = "-", length: int = 40):
        """Print a horizontal separator line."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        color = Logger.COLORS.get("DEBUG", "")
        reset = Logger.RESET
        Logger._emit(f"{color}[{timestamp}] {char * length}{reset}")

    @staticmethod
    def set_verbose(enabled: bool):
        """Toggle verbose-mode logging."""
        Logger.verbose = enabled
194
+
195
+
196
+ # 便捷实例
197
+ log = Logger()
198
+
199
+
200
+ # ============ 配置标准 logging 模块使用颜色 ============
201
+
202
class ColoredFormatter(logging.Formatter):
    """Add per-level ANSI colors to records from the standard logging module."""

    # Mapping from logging level number to ANSI color prefix.
    LEVEL_COLORS = {
        logging.DEBUG: Fore.LIGHTBLACK_EX,
        logging.INFO: Fore.GREEN,
        logging.WARNING: Fore.YELLOW,
        logging.ERROR: Fore.RED,
        logging.CRITICAL: Fore.RED + Style.BRIGHT,
    }

    def format(self, record):
        """Render `record` as 'time | LEVEL | logger | message' in color."""
        color = self.LEVEL_COLORS.get(record.levelno, "")
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return (
            f"{color}{stamp} | {record.levelname} | {record.name} | "
            f"{record.getMessage()}{Style.RESET_ALL}"
        )
228
+
229
+
230
def setup_colored_logging(level=logging.INFO):
    """Install a single colored stdout handler on the root logger.

    Args:
        level: Minimum logging level for the root logger and its handler.

    Returns:
        logging.Logger: The configured root logger.
    """
    root = logging.getLogger()
    root.setLevel(level)

    # Drop any pre-existing handlers so output is not duplicated.
    while root.handlers:
        root.removeHandler(root.handlers[0])

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(level)
    handler.setFormatter(ColoredFormatter())
    root.addHandler(handler)

    return root
247
+
248
+
249
+ # 自动配置logging颜色
250
+ setup_colored_logging(logging.INFO)
251
+
252
+ # 抑制第三方库的英文日志
253
+ logging.getLogger("faiss").setLevel(logging.WARNING)
254
+ logging.getLogger("audio_separator").setLevel(logging.WARNING)
lib/mixer.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 混音模块 - 人声与伴奏混合
4
+ """
5
+ import numpy as np
6
+ import librosa
7
+ import soundfile as sf
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ from lib.audio import soft_clip_array
12
+
13
+ try:
14
+ from lib.logger import log
15
+ except ImportError:
16
+ log = None
17
+
18
+ try:
19
+ from pedalboard import Pedalboard, Reverb, Compressor, Gain
20
+ PEDALBOARD_AVAILABLE = True
21
+ except ImportError:
22
+ PEDALBOARD_AVAILABLE = False
23
+
24
+
25
+ def _probe_sample_rate(path: str, fallback: int = 44100) -> int:
26
+ """Probe sample rate from file metadata."""
27
+ try:
28
+ return int(sf.info(path).samplerate)
29
+ except Exception:
30
+ return int(fallback)
31
+
32
+
33
def load_audio_for_mix(path: str, target_sr: Optional[int] = None) -> tuple:
    """
    Load an audio file for mixing, always returning a stereo array.

    Args:
        path: Audio file path.
        target_sr: Target sample rate; None keeps the file's native rate.

    Returns:
        tuple: (audio_data, sample_rate); audio_data has shape (2, n_samples).
    """
    if log:
        log.detail(f"加载音频: {Path(path).name}")

    data, rate = librosa.load(path, sr=target_sr, mono=False)

    # Duplicate a mono signal into two identical channels.
    if data.ndim == 1:
        data = np.stack([data, data])
        if log:
            log.detail("单声道已扩展为双声道")

    if log:
        log.detail(f"音频形状: {data.shape}, 采样率: {rate}Hz")

    return data, rate
58
+
59
+
60
def apply_reverb(
    audio: np.ndarray,
    sr: int,
    room_size: float = 0.3,
    wet_level: float = 0.2,
) -> np.ndarray:
    """
    Apply a reverb effect to vocals via pedalboard.

    Returns the input unchanged when pedalboard is not installed.

    Args:
        audio: Vocal audio, mono (n,) or multi-channel (channels, n).
        sr: Sample rate.
        room_size: Reverb room size (0-1).
        wet_level: Wet mix level; dry level is set to 1 - wet_level.

    Returns:
        np.ndarray: Processed audio.  NOTE(review): mono input is reshaped to
        (1, n) before processing, so the output shape may differ from the input.
    """
    if not PEDALBOARD_AVAILABLE:
        if log:
            log.warning("Pedalboard 不可用,跳过混响处理")
        return audio

    if log:
        log.detail(f"应用混响: room_size={room_size}, wet_level={wet_level}")

    # pedalboard expects (channels, samples); lift mono to a single channel.
    if audio.ndim == 1:
        audio = audio.reshape(1, -1)

    board = Pedalboard([
        Reverb(room_size=room_size, wet_level=wet_level, dry_level=1.0 - wet_level)
    ])
    processed = board(audio, sr)

    if log:
        log.detail("混响处理完成")

    return processed
87
+
88
+
89
def adjust_audio_length(audio: np.ndarray, target_length: int) -> np.ndarray:
    """Crop or zero-pad audio along the time (last) axis to `target_length`."""
    length = audio.shape[-1]

    if length > target_length:
        return audio[..., :target_length]
    if length == target_length:
        return audio

    # Shorter than requested: pad the tail with zeros.
    shortfall = target_length - length
    if audio.ndim == 1:
        return np.pad(audio, (0, shortfall))
    return np.pad(audio, ((0, 0), (0, shortfall)))
102
+
103
+
104
def mix_vocals_and_accompaniment(
    vocals_path: str,
    accompaniment_path: str,
    output_path: str,
    vocals_volume: float = 1.0,
    accompaniment_volume: float = 1.0,
    reverb_amount: float = 0.0,
    target_sr: Optional[int] = None,
) -> str:
    """
    Mix a vocal track with an accompaniment track and write the result.

    Args:
        vocals_path: Vocal audio path.
        accompaniment_path: Accompaniment audio path.
        output_path: Output file path (parent dirs created as needed).
        vocals_volume: Vocal gain (0-2).
        accompaniment_volume: Accompaniment gain (0-2).
        reverb_amount: Vocal reverb wet level (0-1); 0 disables reverb.
        target_sr: Target sample rate; when None (or <= 0), the higher of
            the two tracks' native rates is used.

    Returns:
        str: The output file path.

    Raises:
        ValueError: If both tracks have zero length.
    """
    # Resolve the working sample rate from file metadata when not given.
    if target_sr is None or target_sr <= 0:
        vocals_sr = _probe_sample_rate(vocals_path)
        accompaniment_sr = _probe_sample_rate(accompaniment_path)
        target_sr = max(vocals_sr, accompaniment_sr)

    if log:
        log.progress("开始混音处理...")
        log.audio(f"人声文件: {Path(vocals_path).name}")
        log.audio(f"伴奏文件: {Path(accompaniment_path).name}")
        log.config(f"人声音量: {vocals_volume}, 伴奏音量: {accompaniment_volume}")
        log.config(f"混响量: {reverb_amount}, 目标采样率: {target_sr}Hz")

    if log:
        log.detail("加载人声音频...")
    vocals, sr = load_audio_for_mix(vocals_path, target_sr)

    if log:
        log.detail("加载伴奏音频...")
    accompaniment, _ = load_audio_for_mix(accompaniment_path, target_sr)

    # Optional reverb, applied to the vocal track only.
    if reverb_amount > 0 and PEDALBOARD_AVAILABLE:
        if log:
            log.progress("应用人声混响...")
        vocals = apply_reverb(vocals, sr, room_size=0.4, wet_level=reverb_amount)
    elif reverb_amount > 0 and log:
        log.warning("Pedalboard 不可用,跳过混响")

    # Apply per-track gain, soft-clipping each track to leave headroom
    # before the tracks are summed.
    vocals = soft_clip_array(vocals * vocals_volume, threshold=0.85, ceiling=0.95)
    accompaniment = soft_clip_array(
        accompaniment * accompaniment_volume,
        threshold=0.85,
        ceiling=0.95,
    )

    vocals_len = vocals.shape[-1]
    accompaniment_len = accompaniment.shape[-1]
    target_len = max(vocals_len, accompaniment_len)

    if target_len <= 0:
        raise ValueError("混音失败:音频长度为 0")

    if log:
        log.detail(f"人声长度: {vocals_len}, 伴奏长度: {accompaniment_len}")
        if vocals_len != accompaniment_len:
            log.detail(f"长度不一致,已补齐到最长长度: {target_len}")

    # Pad the shorter track so both align sample-for-sample.
    vocals = adjust_audio_length(vocals, target_len)
    accompaniment = adjust_audio_length(accompaniment, target_len)

    if log:
        log.progress("混合音轨...")
    mixed = vocals + accompaniment

    max_val = float(np.max(np.abs(mixed)))
    if log:
        log.detail(f"混合后峰值: {max_val:.4f}")

    # Final soft clip to tame summed peaks.
    mixed = soft_clip_array(mixed, threshold=0.90, ceiling=0.98)
    if log:
        final_peak = float(np.max(np.abs(mixed)))
        log.detail(f"软削波后峰值: {final_peak:.4f}")

    # soundfile expects (n_samples, n_channels); our arrays are (channels, samples).
    if mixed.ndim == 2:
        mixed = mixed.T

    output_dir = Path(output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)

    if log:
        log.progress(f"保存混音文件: {output_path}")

    sf.write(output_path, mixed, sr)

    output_size = Path(output_path).stat().st_size
    duration = target_len / sr

    if log:
        log.success("混音完成")
        log.audio(f"输出时长: {duration:.2f}秒")
        log.audio(f"输出大小: {output_size / 1024 / 1024:.2f} MB")

    return output_path
210
+
211
+
212
def check_pedalboard_available() -> bool:
    """Return True when the optional pedalboard package imported successfully."""
    return PEDALBOARD_AVAILABLE
lib/vocal_cleanup.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 音频后处理模块 - 齿音和呼吸音处理
4
+ 基于研究文献的最佳实践
5
+ """
6
+ import numpy as np
7
+ from scipy import signal
8
+ from typing import Optional
9
+
10
+
11
def detect_sibilance_frames(audio: np.ndarray, sr: int, threshold_db: float = -20.0) -> np.ndarray:
    """
    Detect sibilant frames (high-frequency consonants: s, sh, ch, z, ...).

    Ref: "Managing Sibilance" - Sound on Sound.
    Sibilance is concentrated in the 4-10 kHz band; a frame is flagged when
    its band energy exceeds `threshold_db` AND that band dominates the frame.

    Args:
        audio: Mono audio samples.
        sr: Sample rate in Hz.
        threshold_db: High-frequency energy threshold (dB).

    Returns:
        Boolean array, one entry per 10 ms hop; True marks a sibilant frame.
        Empty when the input is shorter than one 20 ms analysis frame.
    """
    # Band-pass 4-10 kHz to isolate the sibilance region.
    nyquist = sr / 2
    low_freq = 4000 / nyquist
    high_freq = min(10000 / nyquist, 0.99)

    sos = signal.butter(4, [low_freq, high_freq], btype='band', output='sos')
    high_freq_audio = signal.sosfilt(sos, audio)

    frame_length = int(0.02 * sr)  # 20 ms analysis frame
    hop_length = int(0.01 * sr)    # 10 ms hop

    # FIX: clamp to zero so inputs shorter than one frame return an empty
    # mask instead of raising ValueError on a negative array size.
    n_frames = max(0, 1 + (len(audio) - frame_length) // hop_length)
    if n_frames == 0:
        return np.zeros(0, dtype=bool)

    high_energy = np.zeros(n_frames)
    total_energy = np.zeros(n_frames)

    for i in range(n_frames):
        start = i * hop_length
        end = start + frame_length
        if end > len(audio):
            break

        # Energy in the sibilance band vs. the whole frame.
        high_energy[i] = np.sum(high_freq_audio[start:end] ** 2)
        total_energy[i] = np.sum(audio[start:end] ** 2)

    # Fraction of each frame's energy that lives in the sibilance band.
    high_ratio = np.zeros_like(high_energy)
    mask = total_energy > 1e-10
    high_ratio[mask] = high_energy[mask] / total_energy[mask]

    # Convert band energy to dB (epsilon avoids log of zero).
    high_energy_db = 10 * np.log10(high_energy + 1e-10)

    # Sibilance = loud in the band AND band-dominant within the frame.
    is_sibilance = (high_energy_db > threshold_db) & (high_ratio > 0.3)

    return is_sibilance
66
+
67
+
68
def reduce_sibilance(audio: np.ndarray, sr: int, reduction_db: float = 6.0) -> np.ndarray:
    """
    De-essing: attenuate detected sibilant frames.

    Ref: "Advanced Sibilance Control" - Mike's Mix Master.
    A per-sample gain curve is built and applied directly in the time
    domain, avoiding the phase problems of band-split compression.

    Args:
        audio: Audio samples.
        sr: Sample rate.
        reduction_db: Sibilance attenuation in dB.

    Returns:
        Processed audio (the input is returned untouched when no sibilant
        frames are found).
    """
    # Locate the frames to attenuate.
    sibilance_frames = detect_sibilance_frames(audio, sr)

    if not np.any(sibilance_frames):
        return audio

    # Build the attenuation gain curve over the whole signal.
    frame_length = int(0.02 * sr)
    hop_length = int(0.01 * sr)

    gain_curve = np.ones(len(audio))
    reduction_factor = 10 ** (-reduction_db / 20)

    for i, is_sib in enumerate(sibilance_frames):
        if is_sib:
            start = i * hop_length
            end = start + frame_length
            if end > len(audio):
                break

            # Fade in/out around the attenuated span to avoid clicks.
            # NOTE(review): fade_in + sustain + fade_out can total slightly
            # less than frame_length from integer division — harmless.
            fade_in = np.linspace(1.0, reduction_factor, frame_length // 4)
            sustain = np.full(frame_length // 2, reduction_factor)
            fade_out = np.linspace(reduction_factor, 1.0, frame_length // 4)
            envelope = np.concatenate([fade_in, sustain, fade_out])

            # Keep the strongest attenuation where overlapping hops meet.
            gain_curve[start:start+len(envelope)] = np.minimum(
                gain_curve[start:start+len(envelope)],
                envelope
            )

    # Apply the gain directly in the time domain (no band splitting).
    result = audio * gain_curve

    return result
119
+
120
+
121
def detect_breath_frames(audio: np.ndarray, sr: int, threshold_db: float = -40.0) -> np.ndarray:
    """
    Detect breath frames.

    Breath characteristics: low energy, broadband noise (high spectral
    flatness), typically between phrases.

    Args:
        audio: Mono audio samples.
        sr: Sample rate in Hz.
        threshold_db: Frame-energy threshold (dB); quieter frames qualify.

    Returns:
        Boolean array, one entry per 10 ms hop; True marks a breath frame.
        Empty when the input is shorter than one 20 ms analysis frame.
    """
    frame_length = int(0.02 * sr)  # 20 ms
    hop_length = int(0.01 * sr)    # 10 ms

    # FIX: clamp to zero so inputs shorter than one frame return an empty
    # mask instead of raising ValueError on a negative array size.
    n_frames = max(0, 1 + (len(audio) - frame_length) // hop_length)
    is_breath = np.zeros(n_frames, dtype=bool)
    if n_frames == 0:
        return is_breath

    for i in range(n_frames):
        start = i * hop_length
        end = start + frame_length
        if end > len(audio):
            break

        frame = audio[start:end]

        # Frame energy in dB (epsilon avoids log of zero).
        energy = np.sum(frame ** 2)
        energy_db = 10 * np.log10(energy + 1e-10)

        # Spectral flatness (geometric / arithmetic mean of the magnitude
        # spectrum) — near 1 for noise-like content, near 0 for tones.
        fft = np.abs(np.fft.rfft(frame))
        geometric_mean = np.exp(np.mean(np.log(fft + 1e-10)))
        arithmetic_mean = np.mean(fft)
        spectral_flatness = geometric_mean / (arithmetic_mean + 1e-10)

        # Breath = quiet AND noise-like.
        is_breath[i] = (energy_db < threshold_db) and (spectral_flatness > 0.5)

    return is_breath
166
+
167
+
168
def reduce_breath_noise(audio: np.ndarray, sr: int, reduction_db: float = 12.0) -> np.ndarray:
    """
    Attenuate breath noise.

    Ref: "How to REALLY Clean Vocals" - Waves.
    A per-sample gain curve is built from the detected breath frames and
    applied in the time domain.

    Args:
        audio: Audio samples.
        sr: Sample rate.
        reduction_db: Breath attenuation in dB.

    Returns:
        Processed audio (the input is returned untouched when no breath
        frames are found).
    """
    # Locate breath frames first.
    breath_frames = detect_breath_frames(audio, sr)

    if not np.any(breath_frames):
        return audio

    # Build the attenuation gain curve over the whole signal.
    frame_length = int(0.02 * sr)
    hop_length = int(0.01 * sr)

    gain_curve = np.ones(len(audio))
    reduction_factor = 10 ** (-reduction_db / 20)

    for i, is_breath in enumerate(breath_frames):
        if is_breath:
            start = i * hop_length
            end = start + frame_length
            if end > len(audio):
                break

            # Smooth fade in/out around the attenuated span to avoid clicks.
            fade_length = frame_length // 4
            fade_in = np.linspace(1.0, reduction_factor, fade_length)
            sustain = np.full(frame_length - 2 * fade_length, reduction_factor)
            fade_out = np.linspace(reduction_factor, 1.0, fade_length)
            envelope = np.concatenate([fade_in, sustain, fade_out])

            # Keep the strongest attenuation where overlapping hops meet.
            gain_curve[start:start+len(envelope)] = np.minimum(
                gain_curve[start:start+len(envelope)],
                envelope
            )

    # Apply the gain curve in the time domain.
    result = audio * gain_curve

    return result
219
+
220
+
221
def apply_vocal_cleanup(
    audio: np.ndarray,
    sr: int,
    reduce_sibilance_enabled: bool = True,
    reduce_breath_enabled: bool = True,
    sibilance_reduction_db: float = 4.0,
    breath_reduction_db: float = 8.0
) -> np.ndarray:
    """
    Run the full vocal clean-up chain.

    Args:
        audio: Audio samples.
        sr: Sample rate.
        reduce_sibilance_enabled: Whether to de-ess.
        reduce_breath_enabled: Whether to attenuate breaths.
        sibilance_reduction_db: Sibilance attenuation in dB.
        breath_reduction_db: Breath attenuation in dB.

    Returns:
        Processed audio (the input array itself is never mutated).
    """
    cleaned = audio.copy()

    # Breath first: it sits at lower energy, so attenuating it does not
    # disturb the subsequent sibilance detection.
    if reduce_breath_enabled:
        cleaned = reduce_breath_noise(cleaned, sr, breath_reduction_db)

    if reduce_sibilance_enabled:
        cleaned = reduce_sibilance(cleaned, sr, sibilance_reduction_db)

    return cleaned
lib/vocoder_fix.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Vocoder伪影修复 - 针对呼吸音电音和长音撕裂
4
+ 基于RVC社区反馈和研究文献
5
+ """
6
+ import numpy as np
7
+ from scipy import signal
8
+ from typing import Optional
9
+
10
+
11
def fix_phase_discontinuity(audio: np.ndarray, sr: int, chunk_boundaries: Optional[list] = None) -> np.ndarray:
    """
    Repair tearing caused by phase discontinuities in vocoder output.

    Reference: "Prosody-Guided Harmonic Attention for Phase-Coherent Neural
    Vocoding" (arXiv:2601.14472). Vocoders can produce phase discontinuities
    on sustained notes, which are audible as tearing.

    Args:
        audio: audio samples
        sr: sample rate
        chunk_boundaries: chunk boundary sample indices. Currently unused;
            kept for interface compatibility with callers.

    Returns:
        Repaired audio (float32). Returned unchanged when no discontinuity
        is detected.
    """
    # Guard: hilbert/percentile are undefined on empty input.
    if len(audio) == 0:
        return audio

    # Instantaneous phase/amplitude via the analytic signal.
    analytic_signal = signal.hilbert(audio)
    instantaneous_phase = np.unwrap(np.angle(analytic_signal))
    amplitude = np.abs(analytic_signal)

    # Detect phase jumps: anything well above the 99th percentile of the
    # per-sample phase increment is treated as a discontinuity.
    phase_diff = np.diff(instantaneous_phase)
    phase_diff_threshold = np.percentile(np.abs(phase_diff), 99) * 2.5

    discontinuities = np.where(np.abs(phase_diff) > phase_diff_threshold)[0]

    if len(discontinuities) == 0:
        return audio

    # Correct each discontinuity by subtracting the jump and linearly
    # fading the correction out over ~20 ms after the jump.
    phase_corrected = instantaneous_phase.copy()

    for disc_idx in discontinuities:
        phase_jump = phase_diff[disc_idx]

        correction_length = min(int(0.02 * sr), len(phase_corrected) - disc_idx - 1)  # 20ms
        if correction_length > 0:
            correction_curve = np.linspace(phase_jump, 0, correction_length)
            phase_corrected[disc_idx + 1:disc_idx + 1 + correction_length] -= correction_curve

    # Resynthesize from the original amplitude envelope and corrected phase.
    corrected_signal = amplitude * np.exp(1j * phase_corrected)
    return np.real(corrected_signal).astype(np.float32)
61
+
62
+
63
def reduce_breath_electric_noise(audio: np.ndarray, sr: int, f0: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Reduce electronic buzz in breath segments of vocoder output.

    Reference: GitHub Issue #65 "Artefacting when speech has breath".
    Problem: the vocoder produces electronic noise in F0 == 0 regions.

    Processing stages:
      1. High-pass filter at 80 Hz to remove DC offset / low-frequency leakage.
      2. Frame-wise detection of low-energy noise frames (wideband or
         high-frequency), optionally vetoed by a voiced F0.
      3. Spectral-gate style attenuation of detected noise frames, with the
         clean/original mix driven by how noisy the signal is overall.

    Args:
        audio: audio samples
        sr: sample rate
        f0: F0 sequence (optional; used to rule out voiced frames)

    Returns:
        Processed audio. Returned unchanged when almost no noise frames
        (< 1%) are detected.
    """
    # Step 1: remove DC offset and very-low-frequency noise (0-80 Hz),
    # a common vocoder low-frequency leakage problem.
    from scipy import signal as scipy_signal

    # Design the high-pass filter: 80 Hz cutoff (normalized to Nyquist).
    nyquist = sr / 2
    cutoff = 80 / nyquist

    # 4th-order Butterworth high-pass.
    # NOTE(review): sosfilt is causal and introduces phase shift; sosfiltfilt
    # would be zero-phase but would change behavior - left as is.
    sos = scipy_signal.butter(4, cutoff, btype='highpass', output='sos')
    audio = scipy_signal.sosfilt(sos, audio)

    # Step 2: detect and clean wideband noise.
    # Frame the signal: 20 ms frames with 10 ms hop (50% overlap).
    frame_length = int(0.02 * sr)  # 20ms
    hop_length = int(0.01 * sr)  # 10ms

    n_frames = 1 + (len(audio) - frame_length) // hop_length

    # Per-frame energy, spectral flatness, and high-frequency energy ratio.
    energy = np.zeros(n_frames)
    spectral_flatness = np.zeros(n_frames)
    high_freq_ratio = np.zeros(n_frames)  # share of energy above 4 kHz

    for i in range(n_frames):
        start = i * hop_length
        end = start + frame_length
        if end > len(audio):
            break

        frame = audio[start:end]

        # Frame energy.
        energy[i] = np.sum(frame ** 2)

        # Spectral flatness (geometric / arithmetic mean of the magnitude
        # spectrum) - high values indicate noise-like content.
        fft = np.abs(np.fft.rfft(frame))
        if np.sum(fft) > 1e-10:
            geometric_mean = np.exp(np.mean(np.log(fft + 1e-10)))
            arithmetic_mean = np.mean(fft)
            spectral_flatness[i] = geometric_mean / (arithmetic_mean + 1e-10)

        # Share of spectral energy at/above 4 kHz.
        freqs = np.fft.rfftfreq(len(frame), 1/sr)
        high_freq_mask = freqs >= 4000
        high_freq_energy = np.sum(fft[high_freq_mask] ** 2)
        total_freq_energy = np.sum(fft ** 2)
        high_freq_ratio[i] = high_freq_energy / (total_freq_energy + 1e-10)

    # Energy in dB.
    energy_db = 10 * np.log10(energy + 1e-10)

    # Adaptive noise-floor detection:
    # 1. look at the energy distribution statistics,
    # 2. take the lowest 5% of frames as noise-floor candidates,
    # 3. within candidates, classify by spectral features.

    # Candidate noise-floor frames: lowest 5% energy.
    candidate_threshold = np.percentile(energy_db, 5)

    # Among candidates, detect true noise:
    # type 1: wideband noise (spectral flatness > 0.35)
    # type 2: high-frequency buzz (high-frequency ratio > 0.15)
    is_candidate = energy_db < candidate_threshold
    is_wideband_noise = is_candidate & (spectral_flatness > 0.35)
    is_highfreq_noise = is_candidate & (high_freq_ratio > 0.15)

    # Union of both noise types.
    is_noise = is_wideband_noise | is_highfreq_noise

    # If very few noise frames were found (< 1%), the audio is already
    # clean - skip processing entirely.
    noise_ratio = is_noise.sum() / len(is_noise)
    if noise_ratio < 0.01:
        return audio

    # If F0 is available, use F0 == 0 as a veto: a voiced frame (F0 > 0)
    # is pitched content, not noise floor.
    if f0 is not None and len(f0) > 0:
        # Map F0 indices onto audio frames.
        f0_per_audio_frame = len(f0) / n_frames
        for i in range(n_frames):
            if not is_noise[i]:
                continue

            f0_idx = int(i * f0_per_audio_frame)
            if f0_idx < len(f0):
                # F0 > 0 means pitched content - not noise floor.
                if f0[f0_idx] > 0:
                    is_noise[i] = False

    # Alias kept from an earlier version; "noise" describes the content
    # being processed more accurately than "breath".
    is_breath = is_noise

    # Scale the cleaning strength with the noise ratio: more noise implies
    # a worse vocoder output and justifies more aggressive cleaning.
    if noise_ratio < 0.05:
        # Little noise (1-5%): gentle cleaning.
        spectral_threshold_percentile = 85  # keep top 15% of bins
        magnitude_attenuation = 0.2  # attenuate the rest to 20%
        mix_ratio = 0.5  # 50% cleaned signal in the mix
    elif noise_ratio < 0.15:
        # Moderate noise (5-15%): medium cleaning.
        spectral_threshold_percentile = 90  # keep top 10%
        magnitude_attenuation = 0.1  # attenuate to 10%
        mix_ratio = 0.7  # 70% cleaned
    else:
        # Heavy noise (> 15%): aggressive cleaning.
        spectral_threshold_percentile = 95  # keep top 5%
        magnitude_attenuation = 0.05  # attenuate to 5%
        mix_ratio = 0.85  # 85% cleaned

    # Apply denoising to the detected noise frames.
    # NOTE(review): frames overlap by 50% and each write below replaces the
    # slice outright, so a later noise frame overwrites part of an earlier
    # one - confirm this overwrite (rather than overlap-add) is intended.
    result = audio.copy()

    for i in range(n_frames):
        if is_breath[i]:
            start = i * hop_length
            end = start + frame_length
            if end > len(audio):
                break

            # Spectral-gate denoising on this frame.
            frame = audio[start:end]

            # FFT: split into magnitude and phase.
            fft = np.fft.rfft(frame)
            magnitude = np.abs(fft)
            phase = np.angle(fft)
            freqs = np.fft.rfftfreq(len(frame), 1/sr)

            # Classify this frame as high-frequency buzz vs wideband noise.
            high_freq_mask = freqs >= 4000
            high_freq_energy = np.sum(magnitude[high_freq_mask] ** 2)
            total_freq_energy = np.sum(magnitude ** 2)
            frame_high_ratio = high_freq_energy / (total_freq_energy + 1e-10)

            if frame_high_ratio > 0.15:
                # High-frequency buzz: attenuate the high band hard.
                magnitude[high_freq_mask] *= 0.05  # high band down to 5%
                # Mild attenuation of the mid band (1-4 kHz).
                mid_freq_mask = (freqs >= 1000) & (freqs < 4000)
                magnitude[mid_freq_mask] *= 0.3
            else:
                # Wideband noise: spectral gate at the configured percentile.
                threshold = np.percentile(magnitude, spectral_threshold_percentile)
                magnitude = np.where(magnitude > threshold, magnitude, magnitude * magnitude_attenuation)

            # Resynthesize the frame from gated magnitude + original phase.
            fft_cleaned = magnitude * np.exp(1j * phase)
            frame_cleaned = np.fft.irfft(fft_cleaned, n=len(frame))

            # Smooth transition at frame edges to avoid clicks.
            fade_length = min(hop_length // 2, len(frame) // 4)
            if fade_length > 0:
                fade_in = np.linspace(0, 1, fade_length)
                fade_out = np.linspace(1, 0, fade_length)

                frame_cleaned[:fade_length] *= fade_in
                frame_cleaned[-fade_length:] *= fade_out

            # Dynamic dry/wet mix.
            result[start:end] = frame * (1 - mix_ratio) + frame_cleaned * mix_ratio

    return result
241
+
242
+
243
def stabilize_sustained_notes(audio: np.ndarray, sr: int, f0: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Stabilize sustained notes to prevent tearing artifacts.

    Reference: "Mel Spectrogram Inversion with Stable Pitch" - Apple Research.
    Vocoders tend to drift in phase during long sustained notes; this smooths
    the amplitude envelope only where it fluctuates sharply.

    Args:
        audio: audio samples
        sr: sample rate
        f0: F0 sequence (used to locate sustained notes)

    Returns:
        Stabilized audio. Returned unchanged when f0 is missing/empty or
        the audio is shorter than one analysis frame.
    """
    if f0 is None or len(f0) == 0:
        return audio

    # Analysis framing: 20 ms frames, 10 ms hop.
    frame_length = int(0.02 * sr)
    hop_length = int(0.01 * sr)

    # Guard: audio shorter than one frame would make n_audio_frames <= 0,
    # which previously crashed (np.zeros of negative size / bad division).
    if len(audio) < frame_length:
        return audio

    # Map F0 indices onto audio frames.
    n_audio_frames = 1 + (len(audio) - frame_length) // hop_length
    f0_per_audio_frame = len(f0) / n_audio_frames

    is_sustained = np.zeros(n_audio_frames, dtype=bool)

    # Detect regions with stable F0.
    # NOTE(review): the window below spans 2 * window_size F0 values, so the
    # `> window_size * 0.8` voiced test is effectively a 40% voiced ratio,
    # not the 80% the original comment claimed - confirm intent before
    # tightening it.
    window_size = 20  # 200 ms window (in frames)
    for i in range(window_size, n_audio_frames - window_size):
        f0_idx = int(i * f0_per_audio_frame)
        if f0_idx >= len(f0):
            break

        # F0 values in the analysis window around this frame.
        f0_window_start = max(0, f0_idx - window_size)
        f0_window_end = min(len(f0), f0_idx + window_size)
        f0_window = f0[f0_window_start:f0_window_end]

        # Drop unvoiced (F0 == 0) values.
        f0_voiced = f0_window[f0_window > 0]

        if len(f0_voiced) > window_size * 0.8:
            # F0 stability: coefficient of variation below 5% counts as a
            # sustained note.
            f0_std = np.std(f0_voiced)
            f0_mean = np.mean(f0_voiced)

            if f0_std / (f0_mean + 1e-6) < 0.05:
                is_sustained[i] = True

    # Apply envelope stabilization to each sustained region.
    result = audio.copy()

    i = 0
    while i < n_audio_frames:
        if is_sustained[i]:
            # Find the extent of this sustained region.
            start_frame = i
            while i < n_audio_frames and is_sustained[i]:
                i += 1
            end_frame = i

            # Convert frame indices to sample indices.
            start_sample = start_frame * hop_length
            end_sample = min(end_frame * hop_length + frame_length, len(audio))

            # Too short to process; i has already advanced past the region.
            if end_sample - start_sample < frame_length:
                continue

            # Extract the sustained segment.
            sustained_segment = audio[start_sample:end_sample]

            # Amplitude envelope via the analytic signal.
            envelope = np.abs(signal.hilbert(sustained_segment))

            # Smooth the envelope with a 50 Hz low-pass (zero-phase).
            b, a = signal.butter(2, 50 / (sr / 2), btype='low')
            smoothed_envelope = signal.filtfilt(b, a, envelope)

            # Only smooth where the envelope fluctuates sharply, to avoid
            # amplifying noise via division elsewhere.
            envelope_variation = np.abs(envelope - smoothed_envelope)
            variation_threshold = np.percentile(envelope_variation, 75)

            # Blend mask: 1 where variation is large (use smoothed envelope),
            # 0 where it is small (keep the original envelope).
            blend_mask = np.clip(envelope_variation / (variation_threshold + 1e-6), 0, 1)

            # Target envelope after blending.
            target_envelope = smoothed_envelope * blend_mask + envelope * (1 - blend_mask)

            # Apply the envelope adjustment as a clipped multiplicative gain
            # so noise cannot be amplified without bound.
            if np.max(envelope) > 1e-6:
                gain = target_envelope / (envelope + 1e-6)
                gain = np.clip(gain, 0.5, 2.0)
                result[start_sample:end_sample] = sustained_segment * gain

        i += 1

    return result
345
+
346
+
347
def apply_vocoder_artifact_fix(
    audio: np.ndarray,
    sr: int,
    f0: Optional[np.ndarray] = None,
    chunk_boundaries: Optional[list] = None,
    fix_phase: bool = True,
    fix_breath: bool = True,
    fix_sustained: bool = True
) -> np.ndarray:
    """
    Apply the complete vocoder artifact repair chain.

    Args:
        audio: audio samples
        sr: sample rate
        f0: F0 sequence
        chunk_boundaries: chunk boundary positions
        fix_phase: repair phase discontinuities
        fix_breath: reduce electronic buzz in breath segments
        fix_sustained: stabilize sustained notes

    Returns:
        Repaired audio.
    """
    # Ordered pipeline: (enabled-flag, stage). Order matters -
    # phase repair first, then breath buzz, then sustained notes.
    pipeline = (
        (fix_phase, lambda a: fix_phase_discontinuity(a, sr, chunk_boundaries)),
        (fix_breath, lambda a: reduce_breath_electric_noise(a, sr, f0)),
        (fix_sustained, lambda a: stabilize_sustained_notes(a, sr, f0)),
    )

    processed = audio.copy()
    for enabled, stage in pipeline:
        if enabled:
            processed = stage(processed)
    return processed