Upload 16 files

- DPTNet_eval/DPTNet_quant_sep.py +108 -0
- DPTNet_eval/asteroid_test/__init__.py +19 -0
- DPTNet_eval/asteroid_test/dsp/__init__.py +5 -0
- DPTNet_eval/asteroid_test/dsp/overlap_add.py +317 -0
- DPTNet_eval/asteroid_test/filterbanks/__init__.py +107 -0
- DPTNet_eval/asteroid_test/filterbanks/enc_dec.py +267 -0
- DPTNet_eval/asteroid_test/filterbanks/free_fb.py +33 -0
- DPTNet_eval/asteroid_test/masknn/__init__.py +12 -0
- DPTNet_eval/asteroid_test/masknn/activations.py +82 -0
- DPTNet_eval/asteroid_test/masknn/attention.py +271 -0
- DPTNet_eval/asteroid_test/masknn/norms.py +156 -0
- DPTNet_eval/asteroid_test/models/__init__.py +59 -0
- DPTNet_eval/asteroid_test/models/base_models.py +351 -0
- DPTNet_eval/asteroid_test/models/dptnet.py +96 -0
- DPTNet_eval/asteroid_test/utils/__init__.py +9 -0
- DPTNet_eval/asteroid_test/utils/torch_utils.py +126 -0
DPTNet_eval/DPTNet_quant_sep.py
ADDED
@@ -0,0 +1,108 @@
# DPTNet_quant_sep.py

import os
import torch
import numpy as np
import torchaudio
from huggingface_hub import hf_hub_download
from . import asteroid_test

torchaudio.set_audio_backend("sox_io")

def get_conf():
    conf_filterbank = {
        'n_filters': 64,
        'kernel_size': 16,
        'stride': 8
    }

    conf_masknet = {
        'in_chan': 64,
        'n_src': 2,
        'out_chan': 64,
        'ff_hid': 256,
        'ff_activation': "relu",
        'norm_type': "gLN",
        'chunk_size': 100,
        'hop_size': 50,
        'n_repeats': 2,
        'mask_act': 'sigmoid',
        'bidirectional': True,
        'dropout': 0
    }
    return conf_filterbank, conf_masknet


def load_dpt_model():
    print('Load Separation Model...')

    # Get the Hugging Face token from the environment variable
    HF_TOKEN = os.getenv("HF_TOKEN")
    if not HF_TOKEN:
        raise EnvironmentError("The HF_TOKEN environment variable is not set! Run export HF_TOKEN=xxx first.")

    # Download the model weights from the Hugging Face Hub
    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",  # <- replace with your own repo name
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=HF_TOKEN
    )

    # Get the model configuration
    conf_filterbank, conf_masknet = get_conf()

    # Build the model architecture
    model_class = getattr(asteroid_test, "DPTNet")
    model = model_class(**conf_filterbank, **conf_masknet)

    # Apply the quantization settings
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.LSTM, torch.nn.Linear},
        dtype=torch.qint8
    )

    # Load the weights (ignore mismatched keys)
    state_dict = torch.load(model_path, map_location="cpu")
    model_state_dict = model.state_dict()
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    return model


def dpt_sep_process(wav_path, model=None, outfilename=None):
    if model is None:
        model = load_dpt_model()

    x, sr = torchaudio.load(wav_path)
    x = x.cpu()

    with torch.no_grad():
        est_sources = model(x)  # shape: (1, 2, T)

    est_sources = est_sources.squeeze(0)  # shape: (2, T)
    sep_1, sep_2 = est_sources  # split into two (T,) tensors

    # Normalize to the mixture's peak level
    max_abs = x[0].abs().max().item()
    sep_1 = sep_1 * max_abs / sep_1.abs().max().item()
    sep_2 = sep_2 * max_abs / sep_2.abs().max().item()

    # Add a channel dimension, becoming (1, T)
    sep_1 = sep_1.unsqueeze(0)
    sep_2 = sep_2.unsqueeze(0)

    # Save the results
    if outfilename is not None:
        torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
        torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
    else:
        torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)


if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")
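For reference, a minimal usage sketch of the module above (not part of the commit; the file names are hypothetical, and a valid HF_TOKEN must already be exported):

# Hypothetical driver; assumes HF_TOKEN is set and mix.wav exists.
from DPTNet_eval.DPTNet_quant_sep import load_dpt_model, dpt_sep_process

model = load_dpt_model()  # downloads and quantizes the checkpoint once
dpt_sep_process("mix.wav", model=model, outfilename="result.wav")
# writes result_sep1.wav, result_sep2.wav and result_mix.wav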
DPTNet_eval/asteroid_test/__init__.py
ADDED
@@ -0,0 +1,19 @@
import pathlib

from .models import DPTNet
from .utils import torch_utils  # noqa

project_root = str(pathlib.Path(__file__).expanduser().absolute().parent.parent)
__version__ = "0.3.4"


def show_available_models():
    from .utils.hub_utils import MODELS_URLS_HASHTABLE

    print(" \n".join(list(MODELS_URLS_HASHTABLE.keys())))


__all__ = [
    "DPTNet",
    "show_available_models",
]
DPTNet_eval/asteroid_test/dsp/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .overlap_add import DualPathProcessing

__all__ = [
    "DualPathProcessing",
]
DPTNet_eval/asteroid_test/dsp/overlap_add.py
ADDED
@@ -0,0 +1,317 @@
import torch
from scipy.signal import get_window
# from asteroid_test.losses import PITLossWrapper
from torch import nn

'''
class LambdaOverlapAdd(torch.nn.Module):
    """Overlap-add with lambda transform on segments.

    Segment input signal, apply lambda function (a neural network for example)
    and combine with OLA.

    Args:
        nnet (callable): Function to apply to each segment.
        n_src (int): Number of sources in the output of nnet.
        window_size (int): Size of segmenting window.
        hop_size (int): Segmentation hop size.
        window (str): Name of the window (see scipy.signal.get_window) used
            for the synthesis.
        reorder_chunks (bool): Whether to reorder each consecutive segment.
            This might be useful when `nnet` is permutation invariant, as
            source assignments might change output channel from one segment
            to the next (in classic speech separation for example).
            Reordering is performed based on the correlation between
            the overlapped part of consecutive segments.

    Examples:
        >>> from asteroid_test import ConvTasNet
        >>> nnet = ConvTasNet(n_src=2)
        >>> continuous_nnet = LambdaOverlapAdd(
        >>>     nnet=nnet,
        >>>     n_src=2,
        >>>     window_size=64000,
        >>>     hop_size=None,
        >>>     window="hanning",
        >>>     reorder_chunks=True,
        >>>     enable_grad=False,
        >>> )
        >>> wav = torch.randn(1, 1, 500000)
        >>> out_wavs = continuous_nnet.forward(wav)
    """

    def __init__(
        self,
        nnet,
        n_src,
        window_size,
        hop_size=None,
        window="hanning",
        reorder_chunks=True,
        enable_grad=False,
    ):
        super().__init__()
        assert window_size % 2 == 0, "Window size must be even"

        self.nnet = nnet
        self.window_size = window_size
        self.hop_size = hop_size if hop_size is not None else window_size // 2
        self.n_src = n_src

        if window:
            window = get_window(window, self.window_size).astype("float32")
            window = torch.from_numpy(window)
            self.use_window = True
        else:
            self.use_window = False

        self.register_buffer("window", window)
        self.reorder_chunks = reorder_chunks
        self.enable_grad = enable_grad

    def ola_forward(self, x):
        """Heart of the class: segment signal, apply func, combine with OLA."""

        assert x.ndim == 3

        batch, channels, n_frames = x.size()
        # Overlap and add:
        # [batch, chans, n_frames] -> [batch, chans, win_size, n_chunks]
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.window_size, 1),
            padding=(self.window_size, 0),
            stride=(self.hop_size, 1),
        )

        out = []
        n_chunks = unfolded.shape[-1]
        for frame_idx in range(n_chunks):  # for loop to spare memory
            frame = self.nnet(unfolded[..., frame_idx])
            # user must handle multichannel by reshaping to batch
            if frame_idx == 0:
                assert frame.ndim == 3, "nnet should return (batch, n_src, time)"
                assert frame.shape[1] == self.n_src, "nnet should return (batch, n_src, time)"
            frame = frame.reshape(batch * self.n_src, -1)

            if frame_idx != 0 and self.reorder_chunks:
                # we determine best perm based on xcorr with previous sources
                frame = _reorder_sources(
                    frame, out[-1], self.n_src, self.window_size, self.hop_size
                )

            if self.use_window:
                frame = frame * self.window
            else:
                frame = frame / (self.window_size / self.hop_size)
            out.append(frame)

        out = torch.stack(out).reshape(n_chunks, batch * self.n_src, self.window_size)
        out = out.permute(1, 2, 0)

        out = torch.nn.functional.fold(
            out,
            (n_frames, 1),
            kernel_size=(self.window_size, 1),
            padding=(self.window_size, 0),
            stride=(self.hop_size, 1),
        )
        return out.squeeze(-1).reshape(batch, self.n_src, -1)

    def forward(self, x):
        """Forward module: segment signal, apply func, combine with OLA.

        Args:
            x (:class:`torch.Tensor`): waveform signal of shape (batch, 1, time).

        Returns:
            :class:`torch.Tensor`: The output of the lambda OLA.
        """
        # Here we can do the reshaping
        with torch.autograd.set_grad_enabled(self.enable_grad):
            olad = self.ola_forward(x)
            return olad


def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """
    Reorder sources in current chunk to maximize correlation with previous chunk.
    Used for Continuous Source Separation. Standard dsp correlation is used
    for reordering.


    Args:
        current (:class:`torch.Tensor`): current chunk, tensor
            of shape (batch, n_src, window_size)
        previous (:class:`torch.Tensor`): previous chunk, tensor
            of shape (batch, n_src, window_size)
        n_src (:class:`int`): number of sources.
        window_size (:class:`int`): window_size, equal to last dimension of
            both current and previous.
        hop_size (:class:`int`): hop_size between current and previous tensors.

    Returns:
        current:

    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size

    def reorder_func(x, y):
        x = x[..., :overlap_f]
        y = y[..., -overlap_f:]
        # Mean normalization
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)
        # Negative mean Correlation
        return -torch.sum(x.unsqueeze(1) * y.unsqueeze(2), dim=-1)

    # We maximize correlation-like between previous and current.
    pit = PITLossWrapper(reorder_func)
    current = pit(current, previous, return_est=True)[1]
    return current.reshape(batch, frames)
'''


class DualPathProcessing(nn.Module):
    """Perform Dual-Path processing via overlap-add as in DPRNN [1].

    Args:
        chunk_size (int): Size of segmenting window.
        hop_size (int): segmentation hop size.

    References:
        [1] "Dual-path RNN: efficient long sequence modeling for
        time-domain single-channel speech separation", Yi Luo, Zhuo Chen
        and Takuya Yoshioka. https://arxiv.org/abs/1910.06379
    """

    def __init__(self, chunk_size, hop_size):
        super(DualPathProcessing, self).__init__()
        self.chunk_size = chunk_size
        self.hop_size = hop_size
        self.n_orig_frames = None

    def unfold(self, x):
        """Unfold the feature tensor from

        (batch, channels, time) to (batch, channels, chunk_size, n_chunks).

        Args:
            x: (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).

        Returns:
            x: (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).

        """
        # x is (batch, chan, frames)
        batch, chan, frames = x.size()
        assert x.ndim == 3
        self.n_orig_frames = x.shape[-1]
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )

        return unfolded.reshape(
            batch, chan, self.chunk_size, -1
        )  # (batch, chan, chunk_size, n_chunks)

    def fold(self, x, output_size=None):
        """Folds back the spliced feature tensor.

        Input shape (batch, channels, chunk_size, n_chunks) to original shape
        (batch, channels, time) using overlap-add.

        Args:
            x: (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            output_size: (int, optional): sequence length of original feature tensor.
                If None, the original length cached by the previous call of `unfold`
                will be used.

        Returns:
            x: (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).

        .. note:: `fold` caches the original length of the input.

        """
        output_size = output_size if output_size is not None else self.n_orig_frames
        # x is (batch, chan, chunk_size, n_chunks)
        batch, chan, chunk_size, n_chunks = x.size()
        to_unfold = x.reshape(batch, chan * self.chunk_size, n_chunks)
        x = torch.nn.functional.fold(
            to_unfold,
            (output_size, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )

        x /= self.chunk_size / self.hop_size

        return x.reshape(batch, chan, self.n_orig_frames)

    @staticmethod
    def intra_process(x, module):
        """Performs intra-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            module (:class:`torch.nn.Module`): module one wish to apply to each chunk
                of the spliced feature tensor.


        Returns:
            x (:class:`torch.Tensor`): processed spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).

        .. note:: the module should have the channel first convention and accept
            a 3D tensor of shape (batch, channels, time).
        """

        # x is (batch, channels, chunk_size, n_chunks)
        batch, channels, chunk_size, n_chunks = x.size()
        # we reshape to batch*chunk_size, channels, n_chunks
        x = x.transpose(1, -1).reshape(batch * n_chunks, chunk_size, channels).transpose(1, -1)
        x = module(x)
        x = x.reshape(batch, n_chunks, channels, chunk_size).transpose(1, -1).transpose(1, 2)
        return x

    @staticmethod
    def inter_process(x, module):
        """Performs inter-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            module (:class:`torch.nn.Module`): module one wish to apply between
                each chunk of the spliced feature tensor.


        Returns:
            x (:class:`torch.Tensor`): processed spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).

        .. note:: the module should have the channel first convention and accept
            a 3D tensor of shape (batch, channels, time).
        """

        batch, channels, chunk_size, n_chunks = x.size()
        x = x.transpose(1, 2).reshape(batch * chunk_size, channels, n_chunks)
        x = module(x)
        x = x.reshape(batch, chunk_size, channels, n_chunks).transpose(1, 2)
        return x
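A minimal round-trip sketch of DualPathProcessing (illustrative only, not part of the commit): with 50% overlap each sample is covered by exactly two chunks, and `fold` divides by that overlap factor, so an identity module recovers the input.

import torch
from DPTNet_eval.asteroid_test.dsp.overlap_add import DualPathProcessing

ola = DualPathProcessing(chunk_size=100, hop_size=50)
feats = torch.randn(1, 64, 500)             # (batch, channels, time)
chunks = ola.unfold(feats)                  # (1, 64, 100, n_chunks)
chunks = ola.intra_process(chunks, torch.nn.Identity())
chunks = ola.inter_process(chunks, torch.nn.Identity())
out = ola.fold(chunks)                      # (1, 64, 500), overlap-added back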
DPTNet_eval/asteroid_test/filterbanks/__init__.py
ADDED
@@ -0,0 +1,107 @@
# from .analytic_free_fb import AnalyticFreeFB
from .free_fb import FreeFB
from .enc_dec import Filterbank, Encoder, Decoder


def make_enc_dec(
    fb_name,
    n_filters,
    kernel_size,
    stride=None,
    who_is_pinv=None,
    padding=0,
    output_padding=0,
    **kwargs,
):
    """Creates congruent encoder and decoder from the same filterbank family.

    Args:
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``]. Can also be a class defined in a
            submodule in this subpackage (e.g. :class:`~.FreeFB`).
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
            be used. If string (among [``'encoder'``, ``'decoder'``]), decides
            which of ``Encoder`` or ``Decoder`` will be the pseudo inverse of
            the other one.
        padding (int): Zero-padding added to both sides of the input.
            Passed to Encoder and Decoder.
        output_padding (int): Additional size added to one side of the output shape.
            Passed to Decoder.
        **kwargs: Arguments which will be passed to the filterbank class
            additionally to the usual `n_filters`, `kernel_size` and `stride`.
            Depends on the filterbank family.
    Returns:
        :class:`.Encoder`, :class:`.Decoder`
    """
    fb_class = get(fb_name)

    if who_is_pinv in ["dec", "decoder"]:
        fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
        enc = Encoder(fb, padding=padding)
        # Decoder filterbank is pseudo inverse of encoder filterbank.
        dec = Decoder.pinv_of(fb)
    elif who_is_pinv in ["enc", "encoder"]:
        fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
        dec = Decoder(fb, padding=padding, output_padding=output_padding)
        # Encoder filterbank is pseudo inverse of decoder filterbank.
        enc = Encoder.pinv_of(fb)
    else:
        fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
        enc = Encoder(fb, padding=padding)
        # Filters between encoder and decoder should not be shared.
        fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
        dec = Decoder(fb, padding=padding, output_padding=output_padding)
    return enc, dec


def register_filterbank(custom_fb):
    """Register a custom filterbank, gettable with `filterbanks.get`.

    Args:
        custom_fb: Custom filterbank to register.

    """
    if custom_fb.__name__ in globals().keys() or custom_fb.__name__.lower() in globals().keys():
        raise ValueError(f"Filterbank {custom_fb.__name__} already exists. Choose another name.")
    globals().update({custom_fb.__name__: custom_fb})


def get(identifier):
    """Returns a filterbank class from a string. Returns its input if it
    is callable (already a :class:`.Filterbank` for example).

    Args:
        identifier (str or Callable or None): the filterbank identifier.

    Returns:
        :class:`.Filterbank` or None
    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError("Could not interpret filterbank identifier: " + str(identifier))
        return cls
    else:
        raise ValueError("Could not interpret filterbank identifier: " + str(identifier))


# Aliases.
free = FreeFB

# For the docs
__all__ = [
    "Filterbank",
    "Encoder",
    "Decoder",
    "FreeFB",
    "make_enc_dec",
]
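An illustrative sketch of make_enc_dec with the free filterbank (not part of the commit; shapes follow standard conv1d arithmetic):

import torch
from DPTNet_eval.asteroid_test.filterbanks import make_enc_dec

enc, dec = make_enc_dec("free", n_filters=64, kernel_size=16, stride=8)
wav = torch.randn(1, 1, 8000)   # (batch, 1, time)
tf = enc(wav)                   # (1, 64, 999): conv1d with stride 8
rec = dec(tf)                   # (1, 1, 8000): transposed-conv overlap-add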
DPTNet_eval/asteroid_test/filterbanks/enc_dec.py
ADDED
@@ -0,0 +1,267 @@
import warnings
import torch
from torch import nn
from torch.nn import functional as F


class Filterbank(nn.Module):
    """Base Filterbank class.
    Each subclass has to implement a `filters` property.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the conv or transposed conv. (Hop size).
            If None (default), set to ``kernel_size // 2``.

    Attributes:
        n_feats_out (int): Number of output filters.
    """

    def __init__(self, n_filters, kernel_size, stride=None):
        super(Filterbank, self).__init__()
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.stride = stride if stride else self.kernel_size // 2
        # If not specified otherwise in the filterbank's init, output
        # number of features is equal to number of required filters.
        self.n_feats_out = n_filters

    @property
    def filters(self):
        """ Abstract method for filters. """
        raise NotImplementedError

    def get_config(self):
        """ Returns dictionary of arguments to re-instantiate the class. """
        config = {
            "fb_name": self.__class__.__name__,
            "n_filters": self.n_filters,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
        }
        return config


class _EncDec(nn.Module):
    """Base private class for Encoder and Decoder.

    Common parameters and methods.

    Args:
        filterbank (:class:`Filterbank`): Filterbank instance. The filterbank
            to use as an encoder or a decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.

    Attributes:
        filterbank (:class:`Filterbank`)
        stride (int)
        is_pinv (bool)
    """

    def __init__(self, filterbank, is_pinv=False):
        super(_EncDec, self).__init__()
        self.filterbank = filterbank
        self.stride = self.filterbank.stride
        self.is_pinv = is_pinv

    @property
    def filters(self):
        return self.filterbank.filters

    def compute_filter_pinv(self, filters):
        """ Computes pseudo inverse filterbank of given filters."""
        scale = self.filterbank.stride / self.filterbank.kernel_size
        shape = filters.shape
        ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
        # Compensate for the overlap-add.
        return ifilt * scale

    def get_filters(self):
        """ Returns filters or pinv filters depending on `is_pinv` attribute """
        if self.is_pinv:
            return self.compute_filter_pinv(self.filters)
        else:
            return self.filters

    def get_config(self):
        """ Returns dictionary of arguments to re-instantiate the class."""
        config = {"is_pinv": self.is_pinv}
        base_config = self.filterbank.get_config()
        return dict(list(base_config.items()) + list(config.items()))


class Encoder(_EncDec):
    """Encoder class.

    Add encoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use
            as an encoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        as_conv1d (bool): Whether to behave like nn.Conv1d.
            If True (default), forwarding input with shape (batch, 1, time)
            will output a tensor of shape (batch, freq, conv_time).
            If False, will output a tensor of shape (batch, 1, freq, conv_time).
        padding (int): Zero-padding added to both sides of the input.

    """

    def __init__(self, filterbank, is_pinv=False, as_conv1d=True, padding=0):
        super(Encoder, self).__init__(filterbank, is_pinv=is_pinv)
        self.as_conv1d = as_conv1d
        self.n_feats_out = self.filterbank.n_feats_out
        self.padding = padding

    @classmethod
    def pinv_of(cls, filterbank, **kwargs):
        """Returns an :class:`~.Encoder`, pseudo inverse of a
        :class:`~.Filterbank` or :class:`~.Decoder`."""
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True, **kwargs)
        elif isinstance(filterbank, Decoder):
            return cls(filterbank.filterbank, is_pinv=True, **kwargs)

    def forward(self, waveform):
        """Convolve input waveform with the filters from a filterbank.
        Args:
            waveform (:class:`torch.Tensor`): any tensor with samples along the
                last dimension; any batch/channel etc. dimensions come first.
        Returns:
            :class:`torch.Tensor`: The corresponding TF domain signal.

        Shapes:
            >>> (time, ) --> (freq, conv_time)
            >>> (batch, time) --> (batch, freq, conv_time)  # Avoid
            >>> if as_conv1d:
            >>>     (batch, 1, time) --> (batch, freq, conv_time)
            >>>     (batch, chan, time) --> (batch, chan, freq, conv_time)
            >>> else:
            >>>     (batch, chan, time) --> (batch, chan, freq, conv_time)
            >>> (batch, any, dim, time) --> (batch, any, dim, freq, conv_time)
        """
        filters = self.get_filters()
        if waveform.ndim == 1:
            # Assumes 1D input with shape (time,)
            # Output will be (freq, conv_time)
            return F.conv1d(
                waveform[None, None], filters, stride=self.stride, padding=self.padding
            ).squeeze()
        elif waveform.ndim == 2:
            # Assume 2D input with shape (batch or channels, time)
            # Output will be (batch or channels, freq, conv_time)
            warnings.warn(
                "Input tensor was 2D. Applying the corresponding "
                "Decoder to the current output will result in a 3D "
                "tensor. This behaviour was introduced to match "
                "Conv1D and ConvTranspose1D, please use 3D inputs "
                "to avoid it. For example, this can be done with "
                "input_tensor.unsqueeze(1)."
            )
            return F.conv1d(
                waveform.unsqueeze(1), filters, stride=self.stride, padding=self.padding
            )
        elif waveform.ndim == 3:
            batch, channels, time_len = waveform.shape
            if channels == 1 and self.as_conv1d:
                # That's the common single channel case (batch, 1, time)
                # Output will be (batch, freq, stft_time), behaves as Conv1D
                return F.conv1d(waveform, filters, stride=self.stride, padding=self.padding)
            else:
                # Return batched convolution, input is (batch, 3, time),
                # output will be (batch, 3, freq, conv_time).
                # Useful for multichannel transforms
                # If as_conv1d is false, (batch, 1, time) will output
                # (batch, 1, freq, conv_time), useful for consistency.
                return self.batch_1d_conv(waveform, filters)
        else:  # waveform.ndim > 3
            # This is to compute "multi"multichannel convolution.
            # Input can be (*, time), output will be (*, freq, conv_time)
            return self.batch_1d_conv(waveform, filters)

    def batch_1d_conv(self, inp, filters):
        # Here we perform multichannel / multi-source convolution.
        # Output should be (batch, channels, freq, conv_time)
        batched_conv = F.conv1d(
            inp.view(-1, 1, inp.shape[-1]), filters, stride=self.stride, padding=self.padding
        )
        output_shape = inp.shape[:-1] + batched_conv.shape[-2:]
        return batched_conv.view(output_shape)


class Decoder(_EncDec):
    """Decoder class.

    Add decoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use as a decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        padding (int): Zero-padding added to both sides of the input.
        output_padding (int): Additional size added to one side of the
            output shape.

    Notes:
        `padding` and `output_padding` arguments are directly passed to
        F.conv_transpose1d.
    """

    def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):
        super().__init__(filterbank, is_pinv=is_pinv)
        self.padding = padding
        self.output_padding = output_padding

    @classmethod
    def pinv_of(cls, filterbank):
        """ Returns a Decoder, pseudo inverse of a filterbank or Encoder."""
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True)
        elif isinstance(filterbank, Encoder):
            return cls(filterbank.filterbank, is_pinv=True)

    def forward(self, spec):
        """Applies transposed convolution to a TF representation.

        This is equivalent to overlap-add.

        Args:
            spec (:class:`torch.Tensor`): 3D or 4D Tensor. The TF
                representation. (Output of :func:`Encoder.forward`).
        Returns:
            :class:`torch.Tensor`: The corresponding time domain signal.
        """
        filters = self.get_filters()
        if spec.ndim == 2:
            # Input is (freq, conv_time), output is (time)
            return F.conv_transpose1d(
                spec.unsqueeze(0),
                filters,
                stride=self.stride,
                padding=self.padding,
                output_padding=self.output_padding,
            ).squeeze()
        if spec.ndim == 3:
            # Input is (batch, freq, conv_time), output is (batch, 1, time)
            return F.conv_transpose1d(
                spec,
                filters,
                stride=self.stride,
                padding=self.padding,
                output_padding=self.output_padding,
            )
        elif spec.ndim > 3:
            # Multiply all the left dimensions together and group them in the
            # batch. Make the convolution and restore.
            view_as = (-1,) + spec.shape[-2:]
            out = F.conv_transpose1d(
                spec.view(view_as),
                filters,
                stride=self.stride,
                padding=self.padding,
                output_padding=self.output_padding,
            )
            return out.view(spec.shape[:-2] + (-1,))
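And a sketch of the pseudo-inverse pairing handled by pinv_of (illustrative, not part of the commit): the decoder reuses the encoder's filters, pseudo-inverted via torch.pinverse at forward time.

import torch
from DPTNet_eval.asteroid_test.filterbanks import FreeFB
from DPTNet_eval.asteroid_test.filterbanks.enc_dec import Encoder, Decoder

fb = FreeFB(n_filters=64, kernel_size=16, stride=8)
enc = Encoder(fb)               # behaves like nn.Conv1d on (batch, 1, time)
dec = Decoder.pinv_of(fb)       # shares fb's filters, pseudo-inverted

spec = enc(torch.randn(1, 1, 8000))   # (1, 64, 999)
wave = dec(spec)                      # (1, 1, 8000)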
DPTNet_eval/asteroid_test/filterbanks/free_fb.py
ADDED
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
from .enc_dec import Filterbank


class FreeFB(Filterbank):
    """Free filterbank without any constraints. Equivalent to
    :class:`nn.Conv1d`.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.

    Attributes:
        n_feats_out (int): Number of output filters.

    References:
        [1] : "Filterbank design for end-to-end speech separation".
        Submitted to ICASSP 2020. Manuel Pariente, Samuele Cornell,
        Antoine Deleforge, Emmanuel Vincent.
    """

    def __init__(self, n_filters, kernel_size, stride=None, **kwargs):
        super(FreeFB, self).__init__(n_filters, kernel_size, stride=stride)
        self._filters = nn.Parameter(torch.ones(n_filters, 1, kernel_size))
        for p in self.parameters():
            nn.init.xavier_normal_(p)

    @property
    def filters(self):
        return self._filters
DPTNet_eval/asteroid_test/masknn/__init__.py
ADDED
@@ -0,0 +1,12 @@
# from .convolutional import TDConvNet, TDConvNetpp, SuDORMRF, SuDORMRFImproved
# from .recurrent import DPRNN, LSTMMasker
from .attention import DPTransformer

__all__ = [
    # "TDConvNet",
    # "DPRNN",
    "DPTransformer",
    # "LSTMMasker",
    # "SuDORMRF",
    # "SuDORMRFImproved",
]
DPTNet_eval/asteroid_test/masknn/activations.py
ADDED
@@ -0,0 +1,82 @@
from functools import partial
import torch
from torch import nn


class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


def linear():
    return nn.Identity()


def relu():
    return nn.ReLU()


def prelu():
    return nn.PReLU()


def leaky_relu():
    return nn.LeakyReLU()


def sigmoid():
    return nn.Sigmoid()


def softmax(dim=None):
    return nn.Softmax(dim=dim)


def tanh():
    return nn.Tanh()


def gelu():
    return nn.GELU()


def swish():
    return Swish()


def register_activation(custom_act):
    """Register a custom activation, gettable with `activation.get`.

    Args:
        custom_act: Custom activation function to register.

    """
    if custom_act.__name__ in globals().keys() or custom_act.__name__.lower() in globals().keys():
        raise ValueError(f"Activation {custom_act.__name__} already exists. Choose another name.")
    globals().update({custom_act.__name__: custom_act})


def get(identifier):
    """Returns an activation function from a string. Returns its input if it
    is callable (already an activation for example).

    Args:
        identifier (str or Callable or None): the activation identifier.

    Returns:
        :class:`nn.Module` or None
    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError("Could not interpret activation identifier: " + str(identifier))
        return cls
    else:
        raise ValueError("Could not interpret activation identifier: " + str(identifier))
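A short sketch of the string-to-module lookup above (illustrative, not part of the commit):

from torch import nn
from DPTNet_eval.asteroid_test.masknn import activations

act = activations.get("sigmoid")()   # name -> factory -> nn.Sigmoid()
assert isinstance(act, nn.Sigmoid)
same = activations.get(nn.ReLU)      # callables are returned unchanged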
DPTNet_eval/asteroid_test/masknn/attention.py
ADDED
@@ -0,0 +1,271 @@
| 1 |
+
from math import ceil
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn.modules.activation import MultiheadAttention
|
| 6 |
+
from ..masknn import activations, norms
|
| 7 |
+
import torch
|
| 8 |
+
from ..dsp.overlap_add import DualPathProcessing
|
| 9 |
+
|
| 10 |
+
import inspect
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ImprovedTransformedLayer(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Improved Transformer module as used in [1].
|
| 16 |
+
It is Multi-Head self-attention followed by LSTM, activation and linear projection layer.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
embed_dim (int): Number of input channels.
|
| 20 |
+
n_heads (int): Number of attention heads.
|
| 21 |
+
dim_ff (int): Number of neurons in the RNNs cell state.
|
| 22 |
+
Defaults to 256. RNN here replaces standard FF linear layer in plain Transformer.
|
| 23 |
+
dropout (float, optional): Dropout ratio, must be in [0,1].
|
| 24 |
+
activation (str, optional): activation function applied at the output of RNN.
|
| 25 |
+
bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
|
| 26 |
+
(Intra-Chunk is always bidirectional).
|
| 27 |
+
norm_type (str, optional): Type of normalization to use.
|
| 28 |
+
|
| 29 |
+
References:
|
| 30 |
+
[1] Chen, Jingjing, Qirong Mao, and Dong Liu.
|
| 31 |
+
"Dual-Path Transformer Network: Direct Context-Aware Modeling for End-to-End Monaural Speech Separation."
|
| 32 |
+
arXiv preprint arXiv:2007.13975 (2020).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
embed_dim,
|
| 38 |
+
n_heads,
|
| 39 |
+
dim_ff,
|
| 40 |
+
dropout=0.0,
|
| 41 |
+
activation="relu",
|
| 42 |
+
bidirectional=True,
|
| 43 |
+
norm="gLN",
|
| 44 |
+
):
|
| 45 |
+
super(ImprovedTransformedLayer, self).__init__()
|
| 46 |
+
|
| 47 |
+
self.mha = MultiheadAttention(embed_dim, n_heads, dropout=dropout)
|
| 48 |
+
# self.linear_first = nn.Linear(embed_dim, 2 * dim_ff) # Added by Kay. 20201119
|
| 49 |
+
self.dropout = nn.Dropout(dropout)
|
| 50 |
+
self.recurrent = nn.LSTM(embed_dim, dim_ff, bidirectional=bidirectional, batch_first=True)
|
| 51 |
+
ff_inner_dim = 2 * dim_ff if bidirectional else dim_ff
|
| 52 |
+
self.linear = nn.Linear(ff_inner_dim, embed_dim)
|
| 53 |
+
self.activation = activations.get(activation)()
|
| 54 |
+
self.norm_mha = norms.get(norm)(embed_dim)
|
| 55 |
+
self.norm_ff = norms.get(norm)(embed_dim)
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
tomha = x.permute(2, 0, 1)
|
| 59 |
+
# x is batch, channels, seq_len
|
| 60 |
+
# mha is seq_len, batch, channels
|
| 61 |
+
# self-attention is applied
|
| 62 |
+
out = self.mha(tomha, tomha, tomha)[0]
|
| 63 |
+
x = self.dropout(out.permute(1, 2, 0)) + x
|
| 64 |
+
x = self.norm_mha(x)
|
| 65 |
+
|
| 66 |
+
# lstm is applied
|
| 67 |
+
out = self.linear(self.dropout(self.activation(self.recurrent(x.transpose(1, -1))[0])))
|
| 68 |
+
x = self.dropout(out.transpose(1, -1)) + x
|
| 69 |
+
return self.norm_ff(x)
|
| 70 |
+
|
| 71 |
+
''' version 0.3.4
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
x = x.transpose(1, -1)
|
| 74 |
+
# x is batch, seq_len, channels
|
| 75 |
+
# self-attention is applied
|
| 76 |
+
out = self.mha(x, x, x)[0]
|
| 77 |
+
x = self.dropout(out) + x
|
| 78 |
+
x = self.norm_mha(x.transpose(1, -1)).transpose(1, -1)
|
| 79 |
+
|
| 80 |
+
# lstm is applied
|
| 81 |
+
out = self.linear(self.dropout(self.activation(self.recurrent(x)[0])))
|
| 82 |
+
# out = self.linear(self.dropout(self.activation(self.linear_first(x)[0])))
|
| 83 |
+
x = self.dropout(out) + x
|
| 84 |
+
return self.norm_ff(x.transpose(1, -1))
|
| 85 |
+
'''
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class DPTransformer(nn.Module):
|
| 89 |
+
"""Dual-path Transformer introduced in [1].
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
in_chan (int): Number of input filters.
|
| 93 |
+
n_src (int): Number of masks to estimate.
|
| 94 |
+
n_heads (int): Number of attention heads.
|
| 95 |
+
ff_hid (int): Number of neurons in the RNNs cell state.
|
| 96 |
+
Defaults to 256.
|
| 97 |
+
chunk_size (int): window size of overlap and add processing.
|
| 98 |
+
Defaults to 100.
|
| 99 |
+
hop_size (int or None): hop size (stride) of overlap and add processing.
|
| 100 |
+
Default to `chunk_size // 2` (50% overlap).
|
| 101 |
+
n_repeats (int): Number of repeats. Defaults to 6.
|
| 102 |
+
norm_type (str, optional): Type of normalization to use.
|
| 103 |
+
ff_activation (str, optional): activation function applied at the output of RNN.
|
| 104 |
+
mask_act (str, optional): Which non-linear function to generate mask.
|
| 105 |
+
bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
|
| 106 |
+
(Intra-Chunk is always bidirectional).
|
| 107 |
+
dropout (float, optional): Dropout ratio, must be in [0,1].
|
| 108 |
+
|
| 109 |
+
References
|
| 110 |
+
[1] Chen, Jingjing, Qirong Mao, and Dong Liu. "Dual-Path Transformer
|
| 111 |
+
Network: Direct Context-Aware Modeling for End-to-End Monaural Speech Separation."
|
| 112 |
+
arXiv (2020).
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
in_chan,
|
| 118 |
+
n_src,
|
| 119 |
+
n_heads=4,
|
| 120 |
+
ff_hid=256,
|
| 121 |
+
chunk_size=100,
|
| 122 |
+
hop_size=None,
|
| 123 |
+
n_repeats=6,
|
| 124 |
+
norm_type="gLN",
|
| 125 |
+
ff_activation="relu",
|
| 126 |
+
mask_act="relu",
|
| 127 |
+
bidirectional=True,
|
| 128 |
+
dropout=0,
|
| 129 |
+
):
|
| 130 |
+
super(DPTransformer, self).__init__()
|
| 131 |
+
self.in_chan = in_chan
|
| 132 |
+
self.n_src = n_src
|
| 133 |
+
self.n_heads = n_heads
|
| 134 |
+
self.ff_hid = ff_hid
|
| 135 |
+
self.chunk_size = chunk_size
|
| 136 |
+
hop_size = hop_size if hop_size is not None else chunk_size // 2
|
| 137 |
+
self.hop_size = hop_size
|
| 138 |
+
self.n_repeats = n_repeats
|
| 139 |
+
self.n_src = n_src
|
| 140 |
+
self.norm_type = norm_type
|
| 141 |
+
self.ff_activation = ff_activation
|
| 142 |
+
self.mask_act = mask_act
|
| 143 |
+
self.bidirectional = bidirectional
|
| 144 |
+
self.dropout = dropout
|
| 145 |
+
|
| 146 |
+
# version 0.3.4
|
| 147 |
+
# self.in_norm = norms.get(norm_type)(in_chan)
|
| 148 |
+
self.mha_in_dim = ceil(self.in_chan / self.n_heads) * self.n_heads
|
| 149 |
+
if self.in_chan % self.n_heads != 0:
|
| 150 |
+
warnings.warn(
|
| 151 |
+
f"DPTransformer input dim ({self.in_chan}) is not a multiple of the number of "
|
| 152 |
+
f"heads ({self.n_heads}). Adding extra linear layer at input to accomodate "
|
| 153 |
+
f"(size [{self.in_chan} x {self.mha_in_dim}])"
|
| 154 |
+
)
|
| 155 |
+
self.input_layer = nn.Linear(self.in_chan, self.mha_in_dim)
|
| 156 |
+
else:
|
| 157 |
+
self.input_layer = None
|
| 158 |
+
|
| 159 |
+
self.in_norm = norms.get(norm_type)(self.mha_in_dim)
|
| 160 |
+
self.ola = DualPathProcessing(self.chunk_size, self.hop_size)
|
| 161 |
+
|
| 162 |
+
# Succession of DPRNNBlocks.
|
| 163 |
+
self.layers = nn.ModuleList([])
|
| 164 |
+
for x in range(self.n_repeats):
|
| 165 |
+
self.layers.append(
|
| 166 |
+
nn.ModuleList(
|
| 167 |
+
[
|
| 168 |
+
ImprovedTransformedLayer(
|
| 169 |
+
self.mha_in_dim,
|
| 170 |
+
self.n_heads,
|
| 171 |
+
self.ff_hid,
|
| 172 |
+
self.dropout,
|
| 173 |
+
self.ff_activation,
|
| 174 |
+
True,
|
| 175 |
+
self.norm_type,
|
| 176 |
+
),
|
| 177 |
+
ImprovedTransformedLayer(
|
| 178 |
+
self.mha_in_dim,
|
| 179 |
+
self.n_heads,
|
| 180 |
+
self.ff_hid,
|
| 181 |
+
self.dropout,
|
| 182 |
+
self.ff_activation,
|
| 183 |
+
self.bidirectional,
|
| 184 |
+
                            self.norm_type,
                        ),
                    ]
                )
            )
        net_out_conv = nn.Conv2d(self.mha_in_dim, n_src * self.in_chan, 1)
        self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
        # Gating and masking in 2D space (after fold)
        self.net_out = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Tanh())
        self.net_gate = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Sigmoid())

        # Get activation function.
        mask_nl_class = activations.get(mask_act)
        # For softmax, feed the source dimension.
        if has_arg(mask_nl_class, "dim"):
            self.output_act = mask_nl_class(dim=1)
        else:
            self.output_act = mask_nl_class()

    def forward(self, mixture_w):
        r"""Forward.

        Args:
            mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$

        Returns:
            :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
        """
        if self.input_layer is not None:
            mixture_w = self.input_layer(mixture_w.transpose(1, 2)).transpose(1, 2)
        mixture_w = self.in_norm(mixture_w)  # [batch, bn_chan, n_frames]
        n_orig_frames = mixture_w.shape[-1]

        mixture_w = self.ola.unfold(mixture_w)
        batch, n_filters, self.chunk_size, n_chunks = mixture_w.size()

        for layer_idx in range(len(self.layers)):
            intra, inter = self.layers[layer_idx]
            mixture_w = self.ola.intra_process(mixture_w, intra)
            mixture_w = self.ola.inter_process(mixture_w, inter)

        output = self.first_out(mixture_w)
        output = output.reshape(batch * self.n_src, self.in_chan, self.chunk_size, n_chunks)
        output = self.ola.fold(output, output_size=n_orig_frames)

        output = self.net_out(output) * self.net_gate(output)
        # Compute mask
        output = output.reshape(batch, self.n_src, self.in_chan, -1)
        est_mask = self.output_act(output)
        return est_mask

    def get_config(self):
        config = {
            "in_chan": self.in_chan,
            "ff_hid": self.ff_hid,
            "n_heads": self.n_heads,
            "chunk_size": self.chunk_size,
            "hop_size": self.hop_size,
            "n_repeats": self.n_repeats,
            "n_src": self.n_src,
            "norm_type": self.norm_type,
            "ff_activation": self.ff_activation,
            "mask_act": self.mask_act,
            "bidirectional": self.bidirectional,
            "dropout": self.dropout,
        }
        return config


def has_arg(fn, name):
    """Checks if a callable accepts a given keyword argument.

    Args:
        fn (callable): Callable to inspect.
        name (str): Check if `fn` can be called with `name` as a keyword
            argument.

    Returns:
        bool: whether `fn` accepts a `name` keyword argument.
    """
    signature = inspect.signature(fn)
    parameter = signature.parameters.get(name)
    if parameter is None:
        return False
    return parameter.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
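For illustration, this is what routes a softmax mask activation to the source dimension: `torch.nn.Softmax.__init__` exposes a `dim` keyword while `torch.nn.Sigmoid.__init__` does not. A minimal sketch (not part of the uploaded file):

import torch
from DPTNet_eval.asteroid_test.masknn.attention import has_arg

has_arg(torch.nn.Softmax, "dim")   # True  -> self.output_act = Softmax(dim=1)
has_arg(torch.nn.Sigmoid, "dim")   # False -> self.output_act = Sigmoid()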
DPTNet_eval/asteroid_test/masknn/norms.py
ADDED
@@ -0,0 +1,156 @@
from functools import partial
import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

EPS = 1e-8


class _LayerNorm(nn.Module):
    """Layer Normalization base class."""

    def __init__(self, channel_size):
        super(_LayerNorm, self).__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """ Assumes input of size `[batch, channel, *]`. """
        return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)


class GlobLN(_LayerNorm):
    """Global Layer Normalization (globLN)."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class ChanLN(_LayerNorm):
    """Channel-wise Layer Normalization (chanLN)."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: chanLN_x `[batch, chan, *]`
        """
        mean = torch.mean(x, dim=1, keepdim=True)
        var = torch.var(x, dim=1, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class CumLN(_LayerNorm):
    """Cumulative Global layer normalization (cumLN)."""

    def forward(self, x):
        """
        Args:
            x (:class:`torch.Tensor`): Shape `[batch, channels, length]`
        Returns:
            :class:`torch.Tensor`: cumLN_x `[batch, channels, length]`
        """
        batch, chan, spec_len = x.size()
        cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=-1)
        cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
        cnt = torch.arange(start=chan, end=chan * (spec_len + 1), step=chan, dtype=x.dtype).view(
            1, 1, -1
        )
        cum_mean = cum_sum / cnt
        # Var = E[x^2] - E[x]^2; the cumulative power sum must also be normalized by cnt.
        cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
        return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())


class FeatsGlobLN(_LayerNorm):
    """Feature-wise global Layer Normalization (FeatsGlobLN).
    Applies normalization over frames for each channel."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, time]`

        Returns:
            :class:`torch.Tensor`: chanLN_x `[batch, chan, time]`
        """
        stop = len(x.size())
        dims = list(range(2, stop))

        mean = torch.mean(x, dim=dims, keepdim=True)
        var = torch.var(x, dim=dims, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class BatchNorm(_BatchNorm):
    """Wrapper class for pytorch BatchNorm1D and BatchNorm2D"""

    def _check_input_dim(self, input):
        if input.dim() < 2 or input.dim() > 4:
            raise ValueError("expected 4D or 3D input (got {}D input)".format(input.dim()))


# Aliases.
gLN = GlobLN
fgLN = FeatsGlobLN
cLN = ChanLN
cgLN = CumLN
bN = BatchNorm


def register_norm(custom_norm):
    """Register a custom norm, gettable with `norms.get`.

    Args:
        custom_norm: Custom norm to register.
    """
    if custom_norm.__name__ in globals().keys() or custom_norm.__name__.lower() in globals().keys():
        raise ValueError(f"Norm {custom_norm.__name__} already exists. Choose another name.")
    globals().update({custom_norm.__name__: custom_norm})


def get(identifier):
    """Returns a norm class from a string. Returns its input if it
    is callable (already a :class:`._LayerNorm` for example).

    Args:
        identifier (str or Callable or None): the norm identifier.

    Returns:
        :class:`._LayerNorm` or None
    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError("Could not interpret normalization identifier: " + str(identifier))
        return cls
    else:
        raise ValueError("Could not interpret normalization identifier: " + str(identifier))
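As a usage sketch (shapes illustrative; not part of the uploaded file), the string identifiers map straight to the classes through the aliases above:

import torch
from DPTNet_eval.asteroid_test.masknn import norms

norm = norms.get("gLN")(channel_size=64)   # "gLN" resolves to GlobLN via the alias
x = torch.randn(4, 64, 1000)               # [batch, chan, time]
assert norm(x).shape == x.shape            # normalized over all dims except batch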
DPTNet_eval/asteroid_test/models/__init__.py
ADDED
@@ -0,0 +1,59 @@
# Models
# from .conv_tasnet import ConvTasNet
# from .dccrnet import DCCRNet
# from .dcunet import DCUNet
# from .dprnn_tasnet import DPRNNTasNet
# from .sudormrf import SuDORMRFImprovedNet, SuDORMRFNet
from .dptnet import DPTNet
# from .lstm_tasnet import LSTMTasNet
# from .demask import DeMask

# Sharing-related
# from .publisher import save_publishable, upload_publishable

__all__ = [
    # "ConvTasNet",
    # "DPRNNTasNet",
    # "SuDORMRFImprovedNet",
    # "SuDORMRFNet",
    "DPTNet",
    # "LSTMTasNet",
    # "DeMask",
    # "DCUNet",
    # "DCCRNet",
    # "save_publishable",
    # "upload_publishable",
]


def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    Args:
        custom_model: Custom model to register.
    """
    if (
        custom_model.__name__ in globals().keys()
        or custom_model.__name__.lower() in globals().keys()
    ):
        raise ValueError(f"Model {custom_model.__name__} already exists. Choose another name.")
    globals().update({custom_model.__name__: custom_model})


def get(identifier):
    """Returns a model class from a string (case-insensitive).

    Args:
        identifier (str): the model name.

    Returns:
        :class:`torch.nn.Module`
    """
    if isinstance(identifier, str):
        to_get = {k.lower(): v for k, v in globals().items()}
        cls = to_get.get(identifier.lower())
        if cls is None:
            raise ValueError(f"Could not interpret model name: {str(identifier)}")
        return cls
    raise ValueError(f"Could not interpret model name: {str(identifier)}")
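The lookup lowercases all module-level names before matching, so capitalization in the identifier doesn't matter; a minimal sketch (not part of the uploaded file):

from DPTNet_eval.asteroid_test import models

assert models.get("DPTNet") is models.get("dptnet")  # both resolve to the DPTNet class
# models.get("DPRNNTasNet") raises ValueError here: its import above is commented out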
DPTNet_eval/asteroid_test/models/base_models.py
ADDED
@@ -0,0 +1,351 @@
import os
import warnings

import numpy as np
import torch
from torch import nn

from ..masknn import activations
from ..utils.torch_utils import pad_x_to_y


def _unsqueeze_to_3d(x):
    if x.ndim == 1:
        return x.reshape(1, 1, -1)
    elif x.ndim == 2:
        return x.unsqueeze(1)
    else:
        return x


class BaseModel(nn.Module):
    def __init__(self):
        print("initialize BaseModel")
        super().__init__()

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    @torch.no_grad()
    def separate(self, wav, output_dir=None, force_overwrite=False, **kwargs):
        """Infer separated sources from input waveforms.
        Also supports filenames.

        Args:
            wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
                Shape: 1D, 2D or 3D tensor, time last.
            output_dir (str): path to save all the wav files. If None,
                estimated sources will be saved next to the original ones.
            force_overwrite (bool): whether to overwrite existing files.
            **kwargs: keyword arguments to be passed to `_separate`.

        Returns:
            Union[torch.Tensor, numpy.ndarray, None], the estimated sources.
                (batch, n_src, time) or (n_src, time) w/o batch dim.

        .. note::
            By default, `separate` calls `_separate` which calls `forward`.
            For models whose `forward` doesn't return waveform tensors,
            overwrite `_separate` to return waveform tensors.
        """
        if isinstance(wav, str):
            self.file_separate(
                wav, output_dir=output_dir, force_overwrite=force_overwrite, **kwargs
            )
        elif isinstance(wav, np.ndarray):
            print("is ndarray")
            # import pdb ; pdb.set_trace()
            return self.numpy_separate(wav, **kwargs)
        elif isinstance(wav, torch.Tensor):
            print("is torch.Tensor")
            return self.torch_separate(wav, **kwargs)
        else:
            raise ValueError(
                f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}"
            )

    def torch_separate(self, wav: torch.Tensor, **kwargs) -> torch.Tensor:
        """ Core logic of `separate`."""
        # Handle device placement
        input_device = wav.device
        model_device = next(self.parameters()).device
        wav = wav.to(model_device)
        # Forward
        out_wavs = self._separate(wav, **kwargs)

        # FIXME: for now this is the best we can do.
        out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())

        # Back to input device (and numpy if necessary)
        out_wavs = out_wavs.to(input_device)
        return out_wavs

    def numpy_separate(self, wav: np.ndarray, **kwargs) -> np.ndarray:
        """ Numpy interface to `separate`."""
        wav = torch.from_numpy(wav)
        out_wav = self.torch_separate(wav, **kwargs)
        out_wav = out_wav.data.numpy()
        return out_wav

    def file_separate(
        self, filename: str, output_dir=None, force_overwrite=False, **kwargs
    ) -> None:
        """ Filename interface to `separate`."""
        import soundfile as sf

        wav, fs = sf.read(filename, dtype="float32", always_2d=True)
        # FIXME: support only single-channel files for now.
        to_save = self.numpy_separate(wav[:, 0], **kwargs)

        # Save wav files to filename_est1.wav etc...
        for src_idx, est_src in enumerate(to_save):
            base = ".".join(filename.split(".")[:-1])
            save_name = base + "_est{}.".format(src_idx + 1) + filename.split(".")[-1]
            if os.path.isfile(save_name) and not force_overwrite:
                warnings.warn(
                    f"File {save_name} already exists, pass `force_overwrite=True` to overwrite it",
                    UserWarning,
                )
                return
            if output_dir is not None:
                save_name = os.path.join(output_dir, save_name.split("/")[-1])
            sf.write(save_name, est_src, fs)

    def _separate(self, wav, *args, **kwargs):
        """Hidden separation method

        Args:
            wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
                Shape: 1D, 2D or 3D tensor, time last.

        Returns:
            The output of self(wav, *args, **kwargs).
        """
        return self(wav, *args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_conf_or_path, *args, **kwargs):
        """Instantiate separation model from a model config (file or dict).

        Args:
            pretrained_model_conf_or_path (Union[dict, str]): model conf as
                returned by `serialize`, or path to it. Needs to contain
                `model_args` and `state_dict` keys.
            *args: Positional arguments to be passed to the model.
            **kwargs: Keyword arguments to be passed to the model.
                They overwrite the ones in the model package.

        Returns:
            nn.Module corresponding to the pretrained model conf/URL.

        Raises:
            ValueError if the input config file doesn't contain the keys
            `model_name`, `model_args` or `state_dict`.
        """
        from . import get  # Avoid circular imports

        if isinstance(pretrained_model_conf_or_path, str):
            # cached_model = self.cached_download(pretrained_model_conf_or_path)
            if os.path.isfile(pretrained_model_conf_or_path):
                cached_model = pretrained_model_conf_or_path
            else:
                raise ValueError(
                    "Model {} is not a file or doesn't exist.".format(pretrained_model_conf_or_path)
                )

            conf = torch.load(cached_model, map_location="cpu")
        else:
            conf = pretrained_model_conf_or_path

        if "model_name" not in conf.keys():
            raise ValueError(
                "Expected config dictionary to have field "
                "`model_name`. Found only: {}".format(conf.keys())
            )
        if "state_dict" not in conf.keys():
            raise ValueError(
                "Expected config dictionary to have field "
                "`state_dict`. Found only: {}".format(conf.keys())
            )
        if "model_args" not in conf.keys():
            raise ValueError(
                "Expected config dictionary to have field "
                "`model_args`. Found only: {}".format(conf.keys())
            )
        conf["model_args"].update(kwargs)  # kwargs overwrite config.
        # Attempt to find the model and instantiate it.
        try:
            model_class = get(conf["model_name"])
        except ValueError:  # Couldn't get the model, maybe custom.
            model = cls(*args, **conf["model_args"])  # Child class.
        else:
            model = model_class(*args, **conf["model_args"])
        model.load_state_dict(conf["state_dict"])
        return model

    def serialize(self):
        """Serialize model and output dictionary.

        Returns:
            dict, serialized model with keys `model_args` and `state_dict`.
        """
        import pytorch_lightning as pl  # Not used in torch.hub

        from .. import __version__ as asteroid_version  # Avoid circular imports

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        # Additional infos
        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__,
            pytorch_lightning_version=pl.__version__,
            asteroid_version=asteroid_version,
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """ In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        raise NotImplementedError

    def cached_download(self, filename_or_url):
        if os.path.isfile(filename_or_url):
            print("is file")
            return filename_or_url
        else:
            print("Model {} is not a file or doesn't exist.".format(filename_or_url))


class BaseEncoderMaskerDecoder(BaseModel):
    """Base class for encoder-masker-decoder separation models.

    Args:
        encoder (Encoder): Encoder instance.
        masker (nn.Module): masker network.
        decoder (Decoder): Decoder instance.
        encoder_activation (Optional[str], optional): Activation to apply after encoder.
            See ``asteroid.masknn.activations`` for valid values.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__()
        self.encoder = encoder
        self.masker = masker
        self.decoder = decoder

        self.encoder_activation = encoder_activation
        self.enc_activation = activations.get(encoder_activation or "linear")()

    def forward(self, wav):
        """Enc/Mask/Dec model forward

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
        """
        # Handle 1D, 2D or n-D inputs
        was_one_d = wav.ndim == 1
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)

        # Real forward
        tf_rep = self.encoder(wav)
        tf_rep = self.postprocess_encoded(tf_rep)
        tf_rep = self.enc_activation(tf_rep)

        est_masks = self.masker(tf_rep)
        est_masks = self.postprocess_masks(est_masks)

        masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
        masked_tf_rep = self.postprocess_masked(masked_tf_rep)

        decoded = self.decoder(masked_tf_rep)
        decoded = self.postprocess_decoded(decoded)

        reconstructed = pad_x_to_y(decoded, wav)
        if was_one_d:
            return reconstructed.squeeze(0)
        else:
            return reconstructed

    def postprocess_encoded(self, tf_rep):
        """Hook to perform transformations on the encoded, time-frequency domain
        representation (output of the encoder) before encoder activation is applied.

        Args:
            tf_rep (Tensor of shape (batch, freq, time)):
                Output of the encoder, before encoder activation is applied.

        Return:
            Transformed `tf_rep`
        """
        return tf_rep

    def postprocess_masks(self, masks):
        """Hook to perform transformations on the masks (output of the masker) before
        masks are applied.

        Args:
            masks (Tensor of shape (batch, n_src, freq, time)):
                Output of the masker

        Return:
            Transformed `masks`
        """
        return masks

    def postprocess_masked(self, masked_tf_rep):
        """Hook to perform transformations on the masked time-frequency domain
        representation (result of masking in the time-frequency domain) before decoding.

        Args:
            masked_tf_rep (Tensor of shape (batch, n_src, freq, time)):
                Masked time-frequency representation, before decoding.

        Return:
            Transformed `masked_tf_rep`
        """
        return masked_tf_rep

    def postprocess_decoded(self, decoded):
        """Hook to perform transformations on the decoded, time domain representation
        (output of the decoder) before original shape reconstruction.

        Args:
            decoded (Tensor of shape (batch, n_src, time)):
                Output of the decoder, before original shape reconstruction.

        Return:
            Transformed `decoded`
        """
        return decoded

    def get_model_args(self):
        """ Arguments needed to re-instantiate the model. """
        fb_config = self.encoder.filterbank.get_config()
        masknet_config = self.masker.get_config()
        # Assert both dicts are disjoint
        if not all(k not in fb_config for k in masknet_config):
            raise AssertionError(
                "Filterbank and Mask network config share "
                "common keys. Merging them is not safe."
            )
        # Merge all args under model_args.
        model_args = {
            **fb_config,
            **masknet_config,
            "encoder_activation": self.encoder_activation,
        }
        return model_args


# Backwards compatibility
BaseTasNet = BaseEncoderMaskerDecoder
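A minimal usage sketch for `separate` (the checkpoint path is hypothetical; note that `torch_separate` rescales the estimates so their summed magnitude matches the input's):

import numpy as np
from DPTNet_eval.asteroid_test.models import DPTNet

model = DPTNet.from_pretrained("exp/best_model.pth")   # hypothetical serialized conf
mixture = np.random.randn(8000).astype("float32")      # 1D mixture, time last
est_sources = model.separate(mixture)                  # ndarray in -> ndarray out
# est_sources.shape == (n_src, 8000): forward squeezes the batch dim for 1D input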
DPTNet_eval/asteroid_test/models/dptnet.py
ADDED
@@ -0,0 +1,96 @@
from ..filterbanks import make_enc_dec
from ..masknn import DPTransformer
from .base_models import BaseEncoderMaskerDecoder


class DPTNet(BaseEncoderMaskerDecoder):
    """DPTNet separation model, as described in [1].

    Args:
        n_src (int): Number of masks to estimate.
        out_chan (int or None): Number of bins in the estimated masks.
            Defaults to `in_chan`.
        bn_chan (int): Number of channels after the bottleneck.
            Defaults to 128.
        hid_size (int): Number of neurons in the RNNs cell state.
            Defaults to 128.
        chunk_size (int): window size of overlap and add processing.
            Defaults to 100.
        hop_size (int or None): hop size (stride) of overlap and add processing.
            Defaults to `chunk_size // 2` (50% overlap).
        n_repeats (int): Number of repeats. Defaults to 6.
        norm_type (str, optional): Type of normalization to use. To choose from

            - ``'gLN'``: global Layernorm
            - ``'cLN'``: channelwise Layernorm
        mask_act (str, optional): Which non-linear function to generate mask.
        bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
            (Intra-Chunk is always bidirectional).
        rnn_type (str, optional): Type of RNN used. Choose between ``'RNN'``,
            ``'LSTM'`` and ``'GRU'``.
        num_layers (int, optional): Number of layers in each RNN.
        dropout (float, optional): Dropout ratio, must be in [0, 1].
        in_chan (int, optional): Number of input channels, should be equal to
            n_filters.
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``].
        n_filters (int): Number of filters / Input dimension of the masker net.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        **fb_kwargs (dict): Additional kwargs to pass to the filterbank
            creation.

    References:
        [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct
        Context-Aware Modeling for End-to-End Monaural Speech Separation"
        Interspeech 2020.
    """

    def __init__(
        self,
        n_src,
        ff_hid=256,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        ff_activation="relu",
        encoder_activation="relu",
        mask_act="relu",
        bidirectional=True,
        dropout=0,
        in_chan=None,
        fb_name="free",
        kernel_size=16,
        n_filters=64,
        stride=8,
        **fb_kwargs,
    ):
        encoder, decoder = make_enc_dec(
            fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs
        )
        n_feats = encoder.n_feats_out
        if in_chan is not None:
            assert in_chan == n_feats, (
                "Number of filterbank output channels"
                " and number of input channels should "
                "be the same. Received "
                f"{n_feats} and {in_chan}"
            )
        # Update in_chan
        masker = DPTransformer(
            n_feats,
            n_src,
            ff_hid=ff_hid,
            ff_activation=ff_activation,
            chunk_size=chunk_size,
            hop_size=hop_size,
            n_repeats=n_repeats,
            norm_type=norm_type,
            mask_act=mask_act,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
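A construction sketch using the defaults above (shapes illustrative; not part of the uploaded file):

import torch
from DPTNet_eval.asteroid_test.models import DPTNet

model = DPTNet(n_src=2, n_repeats=2, chunk_size=100, hop_size=50)
wav = torch.randn(1, 1, 16000)   # (batch, n_mix, time)
est = model(wav)                 # (batch, n_src, time), padded back to input length via pad_x_to_y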
DPTNet_eval/asteroid_test/utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .torch_utils import tensors_to_device, to_cuda

# The functions above were all in asteroid/utils.py before refactoring into
# asteroid/utils/*_utils.py files. They are imported for backward compatibility.

__all__ = [
    "tensors_to_device",
    "to_cuda",
]
DPTNet_eval/asteroid_test/utils/torch_utils.py
ADDED
@@ -0,0 +1,126 @@
import torch
from torch import nn
from collections import OrderedDict


def to_cuda(tensors):  # pragma: no cover (No CUDA on travis)
    """Transfer tensor, dict or list of tensors to GPU.

    Args:
        tensors (:class:`torch.Tensor`, list or dict): May be a single, a
            list or a dictionary of tensors.

    Returns:
        :class:`torch.Tensor`:
            Same as input but transferred to cuda. Goes through lists and dicts
            and transfers the torch.Tensor to cuda. Leaves the rest untouched.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.cuda()
    if isinstance(tensors, list):
        return [to_cuda(tens) for tens in tensors]
    if isinstance(tensors, dict):
        for key in tensors.keys():
            tensors[key] = to_cuda(tensors[key])
        return tensors
    raise TypeError(
        "tensors must be a tensor or a list or dict of tensors. "
        "Got tensors of type {}".format(type(tensors))
    )


def tensors_to_device(tensors, device):
    """Transfer tensor, dict or list of tensors to device.

    Args:
        tensors (:class:`torch.Tensor`): May be a single, a list or a
            dictionary of tensors.
        device (:class:`torch.device`): the device where to place the tensors.

    Returns:
        Union[:class:`torch.Tensor`, list, tuple, dict]:
            Same as input but transferred to device.
            Goes through lists and dicts and transfers the torch.Tensor to
            device. Leaves the rest untouched.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    elif isinstance(tensors, (list, tuple)):
        return [tensors_to_device(tens, device) for tens in tensors]
    elif isinstance(tensors, dict):
        for key in tensors.keys():
            tensors[key] = tensors_to_device(tensors[key], device)
        return tensors
    else:
        return tensors


def pad_x_to_y(x, y, axis=-1):
    """Pad first argument to have same size as second argument.

    Args:
        x (torch.Tensor): Tensor to be padded.
        y (torch.Tensor): Tensor to pad x to.
        axis (int): Axis to pad on.

    Returns:
        torch.Tensor, x padded to match y's shape.
    """
    if axis != -1:
        raise NotImplementedError
    inp_len = y.size(axis)
    output_len = x.size(axis)
    return nn.functional.pad(x, [0, inp_len - output_len])


def load_state_dict_in(state_dict, model):
    """Strictly loads state_dict in model, or the next submodel.
    Useful to load a standalone model after training it with System.

    Args:
        state_dict (OrderedDict): the state_dict to load.
        model (torch.nn.Module): the model to load it into.

    Returns:
        torch.nn.Module: model with loaded weights.

    .. note:: Keys in a state_dict look like ``object1.object2.layer_name.weight.etc``.
        We first try to load the model in the classic way.
        If this fails, we remove the first left part of the key to obtain
        ``object2.layer_name.weight.etc``.
        Blindly loading with ``strict=False`` should be done with some logging
        of the missing keys in the state_dict and the model.
    """
    try:
        # This can fail if the model was included into a bigger nn.Module
        # object. For example, into System.
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Keys look like object1.object2.layer_name.weight.etc
        # The following removes the first left part of the key to obtain
        # object2.layer_name.weight.etc, then retries a strict load.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_k = k[k.find(".") + 1 :]
            new_state_dict[new_k] = v
        model.load_state_dict(new_state_dict, strict=True)
    return model


def are_models_equal(model1, model2):
    """Check for weights equality between models.

    Args:
        model1 (nn.Module): model instance to be compared.
        model2 (nn.Module): second model instance to be compared.

    Returns:
        bool: Whether all model weights are equal.
    """
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True
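Two quick sketches for the helpers above (shapes illustrative; `pad_x_to_y` is what trims or pads the decoder output back to the mixture length in `BaseEncoderMaskerDecoder.forward`):

import torch
from DPTNet_eval.asteroid_test.utils.torch_utils import pad_x_to_y, tensors_to_device

x = torch.randn(2, 3, 15992)    # e.g. decoder output, slightly short
y = torch.randn(2, 1, 16000)    # mixture to match
pad_x_to_y(x, y).shape          # torch.Size([2, 3, 16000]); a negative pad would trim instead

batch = {"mix": torch.randn(4, 16000), "id": "utt1"}
batch = tensors_to_device(batch, torch.device("cpu"))   # tensors moved, non-tensors untouched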