DeepLearning101 committed on
Commit
38d7181
·
verified ·
1 Parent(s): cf73d23

Update DPTNet_eval/DPTNet_quant_sep.py

Browse files
Files changed (1) hide show
  1. DPTNet_eval/DPTNet_quant_sep.py +24 -26
DPTNet_eval/DPTNet_quant_sep.py CHANGED
@@ -4,6 +4,9 @@ import numpy as np
4
  import torchaudio
5
  import yaml
6
  from . import asteroid_test
 
 
 
7
 
8
 
9
  def get_conf():
@@ -32,19 +35,35 @@ def get_conf():
32
 
33
  def load_dpt_model():
34
  print('Load Separation Model...')
35
- now_path = os.path.split(os.path.realpath(__file__))[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  conf_filterbank, conf_masknet = get_conf()
37
- model_path = os.path.join(now_path, "trained_model/train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p")
38
- model = getattr(asteroid_test, "DPTNet")(**conf_filterbank, **conf_masknet)
39
  model = torch.quantization.quantize_dynamic(model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)
 
40
  state_dict = torch.load(model_path, map_location="cpu")
41
  model.load_state_dict(state_dict)
42
  model.eval()
43
  return model
44
 
 
45
  def dpt_sep_process(wav_path, model=None, outfilename=None):
46
  if model is None:
47
- model = load_model()
48
 
49
  x, sr = torchaudio.load(wav_path)
50
  x = x.cpu()
@@ -73,28 +92,7 @@ def dpt_sep_process(wav_path, model=None, outfilename=None):
73
  else:
74
  torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
75
  torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
76
-
77
- # def dpt_sep_process(wav_path, model=None, outfilename=None):
78
- # if model == None:
79
- # model = load_model()
80
- # x, sr = torchaudio.load(wav_path)
81
- # x = x.cpu()
82
- # with torch.no_grad():
83
- # est_sources = model(x)
84
-
85
- # est_sources_np = est_sources.squeeze(0)
86
-
87
- # sep_1, sep_2 = est_sources_np
88
- # sep_1 = sep_1 * x[0].abs().max().item() / sep_1.abs().max().item()
89
- # sep_2 = sep_2 * x[0].abs().max().item() / sep_2.abs().max().item()
90
-
91
- # if outfilename != None:
92
- # torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
93
- # torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
94
- # torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
95
- # else:
96
- # torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
97
- # torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
98
 
99
  if __name__ == '__main__':
100
  print("This module should be used via Flask or Gradio.")
 
4
  import torchaudio
5
  import yaml
6
  from . import asteroid_test
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ torchaudio.set_audio_backend("sox_io")
10
 
11
 
12
  def get_conf():
 
35
 
36
  def load_dpt_model():
37
  print('Load Separation Model...')
38
+
39
+ # 👇 從環境變數取得 HF Token
40
+ from huggingface_hub import hf_hub_download
41
+ speech_sep_token = os.getenv("SpeechSeparation")
42
+ if not speech_sep_token:
43
+ raise EnvironmentError("環境變數 SpeechSeparation 未設定!")
44
+
45
+ # 👇 從 Hugging Face Hub 下載模型權重
46
+ model_path = hf_hub_download(
47
+ repo_id="DeepLearning101/speech-separation", # 替換成你自己的 repo 名稱
48
+ filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
49
+ token=speech_sep_token
50
+ )
51
+
52
+ # 👇 原本邏輯完全不變
53
  conf_filterbank, conf_masknet = get_conf()
54
+ model_class = getattr(asteroid_test, "DPTNet")
55
+ model = model_class(**conf_filterbank, **conf_masknet)
56
  model = torch.quantization.quantize_dynamic(model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)
57
+
58
  state_dict = torch.load(model_path, map_location="cpu")
59
  model.load_state_dict(state_dict)
60
  model.eval()
61
  return model
62
 
63
+
64
  def dpt_sep_process(wav_path, model=None, outfilename=None):
65
  if model is None:
66
+ model = load_dpt_model()
67
 
68
  x, sr = torchaudio.load(wav_path)
69
  x = x.cpu()
 
92
  else:
93
  torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
94
  torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
95
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  if __name__ == '__main__':
98
  print("This module should be used via Flask or Gradio.")