DeepLearning101 commited on
Commit
d8be50a
·
verified ·
1 Parent(s): 956e325

Update DPTNet_eval/DPTNet_quant_sep.py

Browse files
Files changed (1) hide show
  1. DPTNet_eval/DPTNet_quant_sep.py +45 -18
DPTNet_eval/DPTNet_quant_sep.py CHANGED
@@ -1,16 +1,25 @@
1
  # DPTNet_quant_sep.py
 
2
  import warnings
3
  warnings.filterwarnings("ignore", message="Failed to initialize NumPy: _ARRAY_API not found")
 
4
  import os
5
  import torch
6
  import numpy as np
7
  import torchaudio
8
  from huggingface_hub import hf_hub_download
9
- from . import asteroid_test
 
 
 
 
 
10
 
11
  torchaudio.set_audio_backend("sox_io")
12
 
 
13
  def get_conf():
 
14
  conf_filterbank = {
15
  'n_filters': 64,
16
  'kernel_size': 16,
@@ -37,42 +46,60 @@ def get_conf():
37
  def load_dpt_model():
38
  print('Load Separation Model...')
39
 
40
- # 從環境變數取得你設定的 Secret 名稱為 SpeechSeparation
41
  speech_sep_token = os.getenv("SpeechSeparation")
42
  if not speech_sep_token:
43
  raise EnvironmentError("環境變數 SpeechSeparation 未設定!")
44
 
 
45
  model_path = hf_hub_download(
46
- repo_id="DeepLearning101/speech-separation",
47
- filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
48
- token=speech_sep_token # ✅ 使用你設定的 Secret
49
- )
50
 
51
  # 取得模型參數
52
  conf_filterbank, conf_masknet = get_conf()
53
 
54
- # 建立模型架構
55
- model_class = getattr(asteroid_test, "DPTNet")
56
- model = model_class(**conf_filterbank, **conf_masknet)
 
 
 
57
 
58
  # 套用量化設定
59
- model = torch.quantization.quantize_dynamic(
60
- model,
61
- {torch.nn.LSTM, torch.nn.Linear},
62
- dtype=torch.qint8
63
- )
 
 
 
64
 
65
  # 載入權重(忽略不匹配的 keys)
66
  state_dict = torch.load(model_path, map_location="cpu")
67
- model_state_dict = model.state_dict()
68
- filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
69
- model.load_state_dict(filtered_state_dict, strict=False)
70
- model.eval()
 
 
 
71
 
 
 
 
 
 
 
 
72
  return model
73
 
74
 
75
  def dpt_sep_process(wav_path, model=None, outfilename=None):
 
76
  if model is None:
77
  model = load_dpt_model()
78
 
 
1
  # DPTNet_quant_sep.py
2
+
3
  import warnings
4
  warnings.filterwarnings("ignore", message="Failed to initialize NumPy: _ARRAY_API not found")
5
+
6
  import os
7
  import torch
8
  import numpy as np
9
  import torchaudio
10
  from huggingface_hub import hf_hub_download
11
+
12
+ # 動態導入 asteroid_test 中的 DPTNet
13
+ try:
14
+ from . import asteroid_test
15
+ except ImportError as e:
16
+ raise ImportError("無法載入 asteroid_test 模組,請確認該模組與訓練時相同") from e
17
 
18
  torchaudio.set_audio_backend("sox_io")
19
 
20
+
21
  def get_conf():
22
+ """取得模型參數設定"""
23
  conf_filterbank = {
24
  'n_filters': 64,
25
  'kernel_size': 16,
 
46
  def load_dpt_model():
47
  print('Load Separation Model...')
48
 
49
+ # 從環境變數取得 Secret Token
50
  speech_sep_token = os.getenv("SpeechSeparation")
51
  if not speech_sep_token:
52
  raise EnvironmentError("環境變數 SpeechSeparation 未設定!")
53
 
54
+ # 從 HF Hub 下載模型權重
55
  model_path = hf_hub_download(
56
+ repo_id="DeepLearning101/speech-separation",
57
+ filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
58
+ token=speech_sep_token
59
+ )
60
 
61
  # 取得模型參數
62
  conf_filterbank, conf_masknet = get_conf()
63
 
64
+ # 建立模型架構(⚠️ 這邊要與訓練時完全一樣)
65
+ try:
66
+ model_class = getattr(asteroid_test, "DPTNet")
67
+ model = model_class(**conf_filterbank, **conf_masknet)
68
+ except Exception as e:
69
+ raise RuntimeError("模型結構錯誤:請確認 asteroid_test.py 是否與訓練時相同") from e
70
 
71
  # 套用量化設定
72
+ try:
73
+ model = torch.quantization.quantize_dynamic(
74
+ model,
75
+ {torch.nn.LSTM, torch.nn.Linear},
76
+ dtype=torch.qint8
77
+ )
78
+ except Exception as e:
79
+ print("量化設定失敗:", e)
80
 
81
  # 載入權重(忽略不匹配的 keys)
82
  state_dict = torch.load(model_path, map_location="cpu")
83
+ own_state = model.state_dict()
84
+ filtered_state_dict = {
85
+ k: v for k, v in state_dict.items() if k in own_state and v.shape == own_state[k].shape
86
+ }
87
+
88
+ # 忽略找不到的 keys,也不強制要求全部 match
89
+ missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
90
 
91
+ # 印出警告訊息方便除錯
92
+ if missing_keys:
93
+ print("⚠️ Missing keys:", missing_keys)
94
+ if unexpected_keys:
95
+ print("ℹ️ Unexpected keys:", unexpected_keys)
96
+
97
+ model.eval()
98
  return model
99
 
100
 
101
  def dpt_sep_process(wav_path, model=None, outfilename=None):
102
+ """進行語音分離處理"""
103
  if model is None:
104
  model = load_dpt_model()
105