# DPTNet_quant_sep.py
import warnings
warnings.filterwarnings("ignore", message="Failed to initialize NumPy: _ARRAY_API not found")
import os

import torch
import torchaudio
from huggingface_hub import hf_hub_download

# Dynamically import DPTNet from asteroid_test
try:
    from . import asteroid_test
except ImportError as e:
    raise ImportError("Failed to load the asteroid_test module; make sure it matches the training-time version") from e

# set_audio_backend was deprecated and later removed in recent torchaudio
# releases, so guard the call to keep the module importable on newer versions.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("sox_io")
def get_conf():
    """Return the filterbank and masknet hyperparameter configuration."""
    conf_filterbank = {
        'n_filters': 64,
        'kernel_size': 16,
        'stride': 8
    }
    conf_masknet = {
        'in_chan': 64,
        'n_src': 2,
        'out_chan': 64,
        'ff_hid': 256,
        'ff_activation': "relu",
        'norm_type': "gLN",
        'chunk_size': 100,
        'hop_size': 50,
        'n_repeats': 2,
        'mask_act': 'sigmoid',
        'bidirectional': True,
        'dropout': 0
    }
    return conf_filterbank, conf_masknet
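
# Note: these hyperparameters must match the checkpoint exactly; a mismatch
# (e.g. in n_filters or chunk_size) surfaces as shape mismatches when the
# state dict is filtered in load_dpt_model below. As a rough guide,
# chunk_size=100 with hop_size=50 corresponds to 50%-overlapping chunks in
# the dual-path processing, and n_src=2 makes the model emit two sources.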
def load_dpt_model():
    """Build the DPTNet architecture, apply dynamic quantization, and load the weights."""
    print('Load Separation Model...')

    # Read the secret token from the environment
    speech_sep_token = os.getenv("SpeechSeparation")
    if not speech_sep_token:
        raise EnvironmentError("Environment variable SpeechSeparation is not set!")

    # Download the model weights from the HF Hub
    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=speech_sep_token
    )

    # Fetch the model hyperparameters
    conf_filterbank, conf_masknet = get_conf()

    # Build the model architecture (⚠️ must be identical to the one used at training time)
    try:
        model_class = getattr(asteroid_test, "DPTNet")
        model = model_class(**conf_filterbank, **conf_masknet)
    except Exception as e:
        raise RuntimeError("Model architecture error: make sure asteroid_test.py matches the training-time version") from e

    # Apply the dynamic quantization setup
    try:
        model = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.LSTM, torch.nn.Linear},
            dtype=torch.qint8
        )
    except Exception as e:
        print("Quantization setup failed:", e)

    # Load the weights, keeping only keys whose names and shapes match
    state_dict = torch.load(model_path, map_location="cpu")
    own_state = model.state_dict()
    filtered_state_dict = {
        k: v for k, v in state_dict.items() if k in own_state and v.shape == own_state[k].shape
    }

    # Ignore missing keys and do not require an exact match
    missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)

    # Print warnings to ease debugging
    if missing_keys:
        print("⚠️ Missing keys:", missing_keys)
    if unexpected_keys:
        print("ℹ️ Unexpected keys:", unexpected_keys)

    model.eval()
    return model
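
# A minimal smoke test (hedged sketch; assumes 16 kHz mono audio, matching the
# AISHELL training data, and a forward pass over a (batch, time) tensor, which
# is what dpt_sep_process below relies on):
#
#   model = load_dpt_model()
#   dummy = torch.randn(1, 16000)   # 1 second of noise
#   with torch.no_grad():
#       est = model(dummy)          # expected shape: (1, 2, 16000)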
def dpt_sep_process(wav_path, model=None, outfilename=None):
    """Run speech separation on a wav file."""
    if model is None:
        model = load_dpt_model()

    x, sr = torchaudio.load(wav_path)
    x = x.cpu()

    with torch.no_grad():
        est_sources = model(x)  # shape: (1, 2, T)

    est_sources = est_sources.squeeze(0)  # shape: (2, T)
    sep_1, sep_2 = est_sources  # split into two (T,) tensors

    # Normalize each source to the peak level of the mixture
    max_abs = x[0].abs().max().item()
    sep_1 = sep_1 * max_abs / sep_1.abs().max().item()
    sep_2 = sep_2 * max_abs / sep_2.abs().max().item()

    # Add a channel dimension, giving (1, T)
    sep_1 = sep_1.unsqueeze(0)
    sep_2 = sep_2.unsqueeze(0)

    # Save the results
    if outfilename is not None:
        torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
        torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
    else:
        torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")