File size: 4,056 Bytes
64ceedd
d8be50a
956e325
 
d8be50a
64ceedd
 
 
 
 
d8be50a
 
 
 
 
 
64ceedd
 
 
d8be50a
64ceedd
d8be50a
64ceedd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8be50a
a34a59d
 
 
64ceedd
d8be50a
64ceedd
d8be50a
 
 
 
64ceedd
 
 
 
d8be50a
 
 
 
 
 
64ceedd
 
d8be50a
 
 
 
 
 
 
 
64ceedd
 
 
d8be50a
 
 
 
 
 
 
64ceedd
d8be50a
 
 
 
 
 
 
64ceedd
 
 
 
d8be50a
64ceedd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c45cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# DPTNet_quant_sep.py

import warnings
warnings.filterwarnings("ignore", message="Failed to initialize NumPy: _ARRAY_API not found")

import os
import torch
import numpy as np
import torchaudio
from huggingface_hub import hf_hub_download

# 動態導入 asteroid_test 中的 DPTNet
try:
    from . import asteroid_test
except ImportError as e:
    raise ImportError("無法載入 asteroid_test 模組,請確認該模組與訓練時相同") from e

# Select the "sox_io" torchaudio backend used by the torchaudio.load/save calls in this module.
# NOTE(review): set_audio_backend is deprecated in newer torchaudio releases (backend
# selection is automatic there) — confirm it still exists in the pinned torchaudio version.
torchaudio.set_audio_backend("sox_io")


def get_conf():
    """Return the DPTNet construction hyper-parameters.

    Returns:
        tuple[dict, dict]: ``(filterbank_kwargs, masknet_kwargs)``. Both dicts are
        newly created on every call, so callers may mutate them freely. They are
        meant to be unpacked together as keyword arguments into the model
        constructor, so their keys must not overlap.
    """
    filterbank = dict(n_filters=64, kernel_size=16, stride=8)

    masknet = dict(
        in_chan=64,  # presumably must equal n_filters above — confirm against DPTNet
        n_src=2,  # two output sources; the separation code unpacks exactly two
        out_chan=64,
        ff_hid=256,
        ff_activation="relu",
        norm_type="gLN",
        chunk_size=100,
        hop_size=50,
        n_repeats=2,
        mask_act="sigmoid",
        bidirectional=True,
        dropout=0,
    )

    return filterbank, masknet


def _state_entry_matches(saved, current):
    """Return True if checkpoint entry ``saved`` can be loaded in place of ``current``.

    Plain tensors must have identical shapes. Dynamically-quantized modules may
    store state-dict entries that are not tensors (tuples of packed parameters,
    ``torch.dtype`` values, script objects) and have no ``.shape``; tuples/lists
    are compared element-wise and any other object only needs the same type.
    """
    if isinstance(saved, torch.Tensor) or isinstance(current, torch.Tensor):
        return (
            isinstance(saved, torch.Tensor)
            and isinstance(current, torch.Tensor)
            and saved.shape == current.shape
        )
    if isinstance(saved, (tuple, list)) and isinstance(current, (tuple, list)):
        return len(saved) == len(current) and all(
            _state_entry_matches(s, c) for s, c in zip(saved, current)
        )
    return type(saved) is type(current)


def load_dpt_model():
    """Build the int8 dynamically-quantized DPTNet and load its weights from the HF Hub.

    The checkpoint is downloaded from ``DeepLearning101/speech-separation`` using
    the token in the ``SpeechSeparation`` environment variable. Checkpoint entries
    whose key is unknown to the model, or whose shape/type does not match, are
    skipped and printed instead of aborting the load.

    Returns:
        torch.nn.Module: The model, switched to eval mode.

    Raises:
        EnvironmentError: If the ``SpeechSeparation`` environment variable is unset.
        RuntimeError: If ``asteroid_test.DPTNet`` cannot be built from ``get_conf()``.
    """
    print('Load Separation Model...')

    # The Hub token comes from a secret exposed as an environment variable.
    speech_sep_token = os.getenv("SpeechSeparation")
    if not speech_sep_token:
        raise EnvironmentError("環境變數 SpeechSeparation 未設定!")

    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=speech_sep_token
    )

    conf_filterbank, conf_masknet = get_conf()

    # The architecture must be exactly the one used in training, otherwise the
    # checkpoint keys/shapes will not line up.
    try:
        model_class = getattr(asteroid_test, "DPTNet")
        model = model_class(**conf_filterbank, **conf_masknet)
    except Exception as e:
        raise RuntimeError("模型結構錯誤:請確認 asteroid_test.py 是否與訓練時相同") from e

    # The checkpoint holds int8 weights, so the model is quantized the same way
    # before loading. Failure is tolerated (best effort), but the LSTM/Linear
    # weights will then most likely be skipped below and reported as missing.
    try:
        model = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.LSTM, torch.nn.Linear},
            dtype=torch.qint8
        )
    except Exception as e:
        print("量化設定失敗:", e)

    # NOTE(review): torch.load unpickles the file; acceptable only because the
    # checkpoint comes from a specific, token-protected Hub repository.
    state_dict = torch.load(model_path, map_location="cpu")
    own_state = model.state_dict()
    filtered_state_dict = {
        k: v for k, v in state_dict.items()
        if k in own_state and _state_entry_matches(v, own_state[k])
    }
    # Checkpoint keys that were filtered out never reach load_state_dict, so it
    # cannot report them; track them here for debugging.
    skipped_keys = [k for k in state_dict if k not in filtered_state_dict]

    # strict=False: model keys absent from the filtered checkpoint keep their
    # freshly-initialised values and are listed in missing_keys.
    missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)

    if missing_keys:
        print("⚠️ Missing keys:", missing_keys)
    if unexpected_keys:
        print("ℹ️ Unexpected keys:", unexpected_keys)
    if skipped_keys:
        print("ℹ️ Skipped checkpoint keys (unknown or shape/type mismatch):", skipped_keys)

    model.eval()
    return model


def _rescale_to_peak(signal, target_peak):
    """Scale ``signal`` so its peak absolute value equals ``target_peak``.

    An all-zero ``signal`` is returned unchanged instead of dividing by zero,
    which would otherwise fill the output with NaN.
    """
    peak = signal.abs().max().item()
    if peak == 0:
        return signal
    return signal * target_peak / peak


def _output_path(path, suffix):
    """Return ``path`` with its final extension replaced by ``suffix``.

    E.g. ``'dir/a.wav'`` + ``'_sep1.wav'`` -> ``'dir/a_sep1.wav'``. Unlike
    ``str.replace('.wav', ...)`` this never touches directory names and never
    returns ``path`` itself, so an input not named ``*.wav`` cannot be overwritten.
    """
    return os.path.splitext(path)[0] + suffix


def dpt_sep_process(wav_path, model=None, outfilename=None):
    """Separate the recording at ``wav_path`` into two sources and write them as WAV files.

    Args:
        wav_path: Path of the input audio file.
        model: Model from ``load_dpt_model()``; loaded on demand when None.
        outfilename: Optional output path template. When given, writes
            ``<stem>_sep1.wav``, ``<stem>_sep2.wav`` and ``<stem>_mix.wav`` (the
            model input). Otherwise ``<stem>_sep1.wav`` and ``<stem>_sep2.wav``
            are written next to ``wav_path``.

    Each separated source is rescaled so its peak matches the input's peak.
    """
    if model is None:
        model = load_dpt_model()

    x, sr = torchaudio.load(wav_path)  # (channels, T)
    x = x.cpu()
    if x.shape[0] > 1:
        # NOTE(review): the model appears to expect mono input — a (C, T) tensor
        # is treated as a batch of C signals, and the resulting 3-D outputs make
        # torchaudio.save fail. Downmix multi-channel files to a single channel.
        x = x.mean(dim=0, keepdim=True)

    with torch.no_grad():
        est_sources = model(x)  # expected shape: (1, 2, T)

    sep_1, sep_2 = est_sources.squeeze(0)  # two (T,) tensors

    # Match each source's peak to the input's peak, then restore the channel
    # dimension torchaudio.save expects: (1, T).
    peak = x[0].abs().max().item()
    sep_1 = _rescale_to_peak(sep_1, peak).unsqueeze(0)
    sep_2 = _rescale_to_peak(sep_2, peak).unsqueeze(0)

    base = outfilename if outfilename is not None else wav_path
    torchaudio.save(_output_path(base, '_sep1.wav'), sep_1, sr)
    torchaudio.save(_output_path(base, '_sep2.wav'), sep_2, sr)
    if outfilename is not None:
        torchaudio.save(_output_path(base, '_mix.wav'), x, sr)


# Entry-point guard: the module is meant to be imported by a web app (Flask or
# Gradio, per the message below); running it directly only prints this hint.
if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")