mattricesound commited on
Commit
7d6f241
1 Parent(s): e4fc05d

Init dcunet and dptnet

Browse files
cfg/model/audio_diffusion.yaml CHANGED
@@ -1,6 +1,6 @@
1
  # @package _global_
2
- model:
3
- _target_: remfx.models.RemFXModel
4
  lr: 1e-4
5
  lr_beta1: 0.95
6
  lr_beta2: 0.999
 
1
  # @package _global_
2
+ model:
3
+ _target_: remfx.models.RemFx
4
  lr: 1e-4
5
  lr_beta1: 0.95
6
  lr_beta2: 0.999
cfg/model/classifier.yaml CHANGED
@@ -5,7 +5,7 @@ model:
5
  lr_weight_decay: 1e-3
6
  sample_rate: ${sample_rate}
7
  network:
8
- _target_: remfx.models.Cnn14
9
  num_classes: ${num_classes}
10
  n_fft: 4096
11
  hop_length: 512
 
5
  lr_weight_decay: 1e-3
6
  sample_rate: ${sample_rate}
7
  network:
8
+ _target_: remfx.cnn14.Cnn14
9
  num_classes: ${num_classes}
10
  n_fft: 4096
11
  hop_length: 512
cfg/model/dcunet.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ model:
3
+ _target_: remfx.models.RemFx
4
+ lr: 1e-4
5
+ lr_beta1: 0.95
6
+ lr_beta2: 0.999
7
+ lr_eps: 1e-6
8
+ lr_weight_decay: 1e-3
9
+ sample_rate: ${sample_rate}
10
+ network:
11
+ _target_: remfx.models.DCUNetModel
12
+ spec_dim: 256 + 1
13
+ hidden_dim: 768
14
+ filter_len: 512
15
+ hop_len: 64
16
+ block_layers: 4
17
+ layers: 4
18
+ kernel_size: 3
19
+ refine_layers: 1
20
+ is_mask: True
21
+ norm: 'ins'
22
+ act: 'comp'
cfg/model/demucs.yaml CHANGED
@@ -1,6 +1,6 @@
1
  # @package _global_
2
  model:
3
- _target_: remfx.models.RemFXModel
4
  lr: 1e-4
5
  lr_beta1: 0.95
6
  lr_beta2: 0.999
 
1
  # @package _global_
2
  model:
3
+ _target_: remfx.models.RemFx
4
  lr: 1e-4
5
  lr_beta1: 0.95
6
  lr_beta2: 0.999
cfg/model/dptnet.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ model:
3
+ _target_: remfx.models.RemFx
4
+ lr: 1e-4
5
+ lr_beta1: 0.95
6
+ lr_beta2: 0.999
7
+ lr_eps: 1e-6
8
+ lr_weight_decay: 1e-3
9
+ sample_rate: ${sample_rate}
10
+ network:
11
+ _target_: remfx.models.DPTNetModel
12
+ enc_dim: 256
13
+ feature_dim: 64
14
+ hidden_dim: 128
15
+ layer: 6
16
+ segment_size: 250
17
+ nspk: 1
18
+ win_len: 2
cfg/model/umx.yaml CHANGED
@@ -1,6 +1,6 @@
1
  # @package _global_
2
- model:
3
- _target_: remfx.models.RemFXModel
4
  lr: 1e-4
5
  lr_beta1: 0.95
6
  lr_beta2: 0.999
 
1
  # @package _global_
2
+ model:
3
+ _target_: remfx.models.RemFx
4
  lr: 1e-4
5
  lr_beta1: 0.95
6
  lr_beta2: 0.999
remfx/cnn14.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from utils import init_bn, init_layer
6
+
7
+ # adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
8
+
9
+
10
+ class Cnn14(nn.Module):
11
+ def __init__(
12
+ self,
13
+ num_classes: int,
14
+ sample_rate: float,
15
+ n_fft: int = 2048,
16
+ hop_length: int = 512,
17
+ n_mels: int = 128,
18
+ ):
19
+ super().__init__()
20
+ self.num_classes = num_classes
21
+ self.n_fft = n_fft
22
+ self.hop_length = hop_length
23
+
24
+ window = torch.hann_window(n_fft)
25
+ self.register_buffer("window", window)
26
+
27
+ self.melspec = torchaudio.transforms.MelSpectrogram(
28
+ sample_rate,
29
+ n_fft,
30
+ hop_length=hop_length,
31
+ n_mels=n_mels,
32
+ )
33
+
34
+ self.bn0 = nn.BatchNorm2d(n_mels)
35
+
36
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
37
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
38
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
39
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
40
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
41
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
42
+
43
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
44
+ self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
45
+
46
+ self.init_weight()
47
+
48
+ def init_weight(self):
49
+ init_bn(self.bn0)
50
+ init_layer(self.fc1)
51
+ init_layer(self.fc_audioset)
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ """
55
+ Input: (batch_size, data_length)"""
56
+
57
+ x = self.melspec(x)
58
+ x = x.permute(0, 2, 1, 3)
59
+ x = self.bn0(x)
60
+ x = x.permute(0, 2, 1, 3)
61
+
62
+ if self.training:
63
+ pass
64
+ # x = self.spec_augmenter(x)
65
+
66
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
67
+ x = F.dropout(x, p=0.2, training=self.training)
68
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
69
+ x = F.dropout(x, p=0.2, training=self.training)
70
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
71
+ x = F.dropout(x, p=0.2, training=self.training)
72
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
73
+ x = F.dropout(x, p=0.2, training=self.training)
74
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
75
+ x = F.dropout(x, p=0.2, training=self.training)
76
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
77
+ x = F.dropout(x, p=0.2, training=self.training)
78
+ x = torch.mean(x, dim=3)
79
+
80
+ (x1, _) = torch.max(x, dim=2)
81
+ x2 = torch.mean(x, dim=2)
82
+ x = x1 + x2
83
+ x = F.dropout(x, p=0.5, training=self.training)
84
+ x = F.relu_(self.fc1(x))
85
+ clipwise_output = self.fc_audioset(x)
86
+
87
+ return clipwise_output
88
+
89
+
90
+ class ConvBlock(nn.Module):
91
+ def __init__(self, in_channels, out_channels):
92
+ super(ConvBlock, self).__init__()
93
+
94
+ self.conv1 = nn.Conv2d(
95
+ in_channels=in_channels,
96
+ out_channels=out_channels,
97
+ kernel_size=(3, 3),
98
+ stride=(1, 1),
99
+ padding=(1, 1),
100
+ bias=False,
101
+ )
102
+
103
+ self.conv2 = nn.Conv2d(
104
+ in_channels=out_channels,
105
+ out_channels=out_channels,
106
+ kernel_size=(3, 3),
107
+ stride=(1, 1),
108
+ padding=(1, 1),
109
+ bias=False,
110
+ )
111
+
112
+ self.bn1 = nn.BatchNorm2d(out_channels)
113
+ self.bn2 = nn.BatchNorm2d(out_channels)
114
+
115
+ self.init_weight()
116
+
117
+ def init_weight(self):
118
+ init_layer(self.conv1)
119
+ init_layer(self.conv2)
120
+ init_bn(self.bn1)
121
+ init_bn(self.bn2)
122
+
123
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
124
+ x = input
125
+ x = F.relu_(self.bn1(self.conv1(x)))
126
+ x = F.relu_(self.bn2(self.conv2(x)))
127
+ if pool_type == "max":
128
+ x = F.max_pool2d(x, kernel_size=pool_size)
129
+ elif pool_type == "avg":
130
+ x = F.avg_pool2d(x, kernel_size=pool_size)
131
+ elif pool_type == "avg+max":
132
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
133
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
134
+ x = x1 + x2
135
+ else:
136
+ raise Exception("Incorrect argument!")
137
+
138
+ return x
remfx/datasets.py CHANGED
@@ -250,6 +250,7 @@ class VocalSet(Dataset):
250
  # Normalize
251
  normalized_dry = self.normalize(dry)
252
  normalized_wet = self.normalize(wet)
 
253
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
254
 
255
 
 
250
  # Normalize
251
  normalized_dry = self.normalize(dry)
252
  normalized_wet = self.normalize(wet)
253
+
254
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
255
 
256
 
remfx/dcunet.py ADDED
@@ -0,0 +1,649 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/AppleHolic/source_separation/tree/master/source_separation
2
+
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from utils import single, concat_complex
9
+ from torch.nn.init import calculate_gain
10
+ from typing import Tuple
11
+ from scipy.signal import get_window
12
+ from librosa.util import pad_center
13
+
14
+
15
+ class ComplexConvBlock(nn.Module):
16
+ """
17
+ Convolution block
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ in_channels: int,
23
+ out_channels: int,
24
+ kernel_size: int,
25
+ padding: int = 0,
26
+ layers: int = 4,
27
+ bn_func=nn.BatchNorm1d,
28
+ act_func=nn.LeakyReLU,
29
+ skip_res: bool = False,
30
+ ):
31
+ super().__init__()
32
+ # modules
33
+ self.blocks = nn.ModuleList()
34
+ self.skip_res = skip_res
35
+
36
+ for idx in range(layers):
37
+ in_ = in_channels if idx == 0 else out_channels
38
+ self.blocks.append(
39
+ nn.Sequential(
40
+ *[
41
+ bn_func(in_),
42
+ act_func(),
43
+ ComplexConv1d(in_, out_channels, kernel_size, padding=padding),
44
+ ]
45
+ )
46
+ )
47
+
48
+ def forward(self, x: torch.tensor) -> torch.tensor:
49
+ temp = x
50
+ for idx, block in enumerate(self.blocks):
51
+ x = block(x)
52
+
53
+ if temp.size() != x.size() or self.skip_res:
54
+ return x
55
+ else:
56
+ return x + temp
57
+
58
+
59
+ class SpectrogramUnet(nn.Module):
60
+ def __init__(
61
+ self,
62
+ spec_dim: int,
63
+ hidden_dim: int,
64
+ filter_len: int,
65
+ hop_len: int,
66
+ layers: int = 3,
67
+ block_layers: int = 3,
68
+ kernel_size: int = 5,
69
+ is_mask: bool = False,
70
+ norm: str = "bn",
71
+ act: str = "tanh",
72
+ ):
73
+ super().__init__()
74
+ self.layers = layers
75
+ self.is_mask = is_mask
76
+
77
+ # stft modules
78
+ self.stft = STFT(filter_len, hop_len)
79
+
80
+ if norm == "bn":
81
+ self.bn_func = nn.BatchNorm1d
82
+ elif norm == "ins":
83
+ self.bn_func = lambda x: nn.InstanceNorm1d(x, affine=True)
84
+ else:
85
+ raise NotImplementedError("{} is not implemented !".format(norm))
86
+
87
+ if act == "tanh":
88
+ self.act_func = nn.Tanh
89
+ self.act_out = nn.Tanh
90
+ elif act == "comp":
91
+ self.act_func = ComplexActLayer
92
+ self.act_out = lambda: ComplexActLayer(is_out=True)
93
+ else:
94
+ raise NotImplementedError("{} is not implemented !".format(act))
95
+
96
+ # prev conv
97
+ self.prev_conv = ComplexConv1d(spec_dim * 2, hidden_dim, 1)
98
+
99
+ # down
100
+ self.down = nn.ModuleList()
101
+ self.down_pool = nn.MaxPool1d(3, stride=2, padding=1)
102
+ for idx in range(self.layers):
103
+ block = ComplexConvBlock(
104
+ hidden_dim,
105
+ hidden_dim,
106
+ kernel_size=kernel_size,
107
+ padding=kernel_size // 2,
108
+ bn_func=self.bn_func,
109
+ act_func=self.act_func,
110
+ layers=block_layers,
111
+ )
112
+ self.down.append(block)
113
+
114
+ # up
115
+ self.up = nn.ModuleList()
116
+ for idx in range(self.layers):
117
+ in_c = hidden_dim if idx == 0 else hidden_dim * 2
118
+ self.up.append(
119
+ nn.Sequential(
120
+ ComplexConvBlock(
121
+ in_c,
122
+ hidden_dim,
123
+ kernel_size=kernel_size,
124
+ padding=kernel_size // 2,
125
+ bn_func=self.bn_func,
126
+ act_func=self.act_func,
127
+ layers=block_layers,
128
+ ),
129
+ self.bn_func(hidden_dim),
130
+ self.act_func(),
131
+ ComplexTransposedConv1d(
132
+ hidden_dim, hidden_dim, kernel_size=2, stride=2
133
+ ),
134
+ )
135
+ )
136
+
137
+ # out_conv
138
+ self.out_conv = nn.Sequential(
139
+ ComplexConvBlock(
140
+ hidden_dim * 2,
141
+ spec_dim * 2,
142
+ kernel_size=kernel_size,
143
+ padding=kernel_size // 2,
144
+ bn_func=self.bn_func,
145
+ act_func=self.act_func,
146
+ ),
147
+ self.bn_func(spec_dim * 2),
148
+ self.act_func(),
149
+ )
150
+
151
+ # refine conv
152
+ self.refine_conv = nn.Sequential(
153
+ ComplexConvBlock(
154
+ spec_dim * 4,
155
+ spec_dim * 2,
156
+ kernel_size=kernel_size,
157
+ padding=kernel_size // 2,
158
+ bn_func=self.bn_func,
159
+ act_func=self.act_func,
160
+ ),
161
+ self.bn_func(spec_dim * 2),
162
+ self.act_func(),
163
+ )
164
+
165
+ def log_stft(self, wav):
166
+ # stft
167
+ mag, phase = self.stft.transform(wav)
168
+ return torch.log(mag + 1), phase
169
+
170
+ def exp_istft(self, log_mag, phase):
171
+ # exp
172
+ mag = np.e**log_mag - 1
173
+ # istft
174
+ wav = self.stft.inverse(mag, phase)
175
+ return wav
176
+
177
+ def adjust_diff(self, x, target):
178
+ size_diff = target.size()[-1] - x.size()[-1]
179
+ assert size_diff >= 0
180
+ if size_diff > 0:
181
+ x = F.pad(
182
+ x.unsqueeze(1), (size_diff // 2, size_diff // 2), "reflect"
183
+ ).squeeze(1)
184
+ return x
185
+
186
+ def masking(self, mag, phase, origin_mag, origin_phase):
187
+ abs_mag = torch.abs(mag)
188
+ mag_mask = torch.tanh(abs_mag)
189
+ phase_mask = mag / abs_mag
190
+
191
+ # masking
192
+ mag = mag_mask * origin_mag
193
+ phase = phase_mask * (origin_phase + phase)
194
+ return mag, phase
195
+
196
+ def forward(self, wav):
197
+ # stft
198
+ origin_mag, origin_phase = self.log_stft(wav)
199
+ origin_x = torch.cat([origin_mag, origin_phase], dim=1)
200
+
201
+ # prev
202
+ x = self.prev_conv(origin_x)
203
+
204
+ # body
205
+ # down
206
+ down_cache = []
207
+ for idx, block in enumerate(self.down):
208
+ x = block(x)
209
+ down_cache.append(x)
210
+ x = self.down_pool(x)
211
+
212
+ # up
213
+ for idx, block in enumerate(self.up):
214
+ x = block(x)
215
+ res = F.interpolate(
216
+ down_cache[self.layers - (idx + 1)],
217
+ size=[x.size()[2]],
218
+ mode="linear",
219
+ align_corners=False,
220
+ )
221
+ x = concat_complex(x, res, dim=1)
222
+
223
+ # match spec dimension
224
+ x = self.out_conv(x)
225
+ if origin_mag.size(2) != x.size(2):
226
+ x = F.interpolate(
227
+ x, size=[origin_mag.size(2)], mode="linear", align_corners=False
228
+ )
229
+
230
+ # refine
231
+ x = self.refine_conv(concat_complex(x, origin_x))
232
+
233
+ def to_wav(stft):
234
+ mag, phase = stft.chunk(2, 1)
235
+ if self.is_mask:
236
+ mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
237
+ out = self.exp_istft(mag, phase)
238
+ out = self.adjust_diff(out, wav)
239
+ return out
240
+
241
+ refine_wav = to_wav(x)
242
+
243
+ return refine_wav
244
+
245
+
246
+ class RefineSpectrogramUnet(SpectrogramUnet):
247
+ def __init__(
248
+ self,
249
+ spec_dim: int,
250
+ hidden_dim: int,
251
+ filter_len: int,
252
+ hop_len: int,
253
+ layers: int = 4,
254
+ block_layers: int = 4,
255
+ kernel_size: int = 3,
256
+ is_mask: bool = True,
257
+ norm: str = "ins",
258
+ act: str = "comp",
259
+ refine_layers: int = 1,
260
+ add_spec_results: bool = False,
261
+ ):
262
+ super().__init__(
263
+ spec_dim,
264
+ hidden_dim,
265
+ filter_len,
266
+ hop_len,
267
+ layers,
268
+ block_layers,
269
+ kernel_size,
270
+ is_mask,
271
+ norm,
272
+ act,
273
+ )
274
+ self.add_spec_results = add_spec_results
275
+ # refine conv
276
+ self.refine_conv = nn.ModuleList(
277
+ [
278
+ nn.Sequential(
279
+ ComplexConvBlock(
280
+ spec_dim * 2,
281
+ spec_dim * 2,
282
+ kernel_size=kernel_size,
283
+ padding=kernel_size // 2,
284
+ bn_func=self.bn_func,
285
+ act_func=self.act_func,
286
+ ),
287
+ self.bn_func(spec_dim * 2),
288
+ self.act_func(),
289
+ )
290
+ ]
291
+ * refine_layers
292
+ )
293
+
294
+ def forward(self, wav):
295
+ # stft
296
+ origin_mag, origin_phase = self.log_stft(wav)
297
+ origin_x = torch.cat([origin_mag, origin_phase], dim=1)
298
+
299
+ # prev
300
+ x = self.prev_conv(origin_x)
301
+
302
+ # body
303
+ # down
304
+ down_cache = []
305
+ for idx, block in enumerate(self.down):
306
+ x = block(x)
307
+ down_cache.append(x)
308
+ x = self.down_pool(x)
309
+
310
+ # up
311
+ for idx, block in enumerate(self.up):
312
+ x = block(x)
313
+ res = F.interpolate(
314
+ down_cache[self.layers - (idx + 1)],
315
+ size=[x.size()[2]],
316
+ mode="linear",
317
+ align_corners=False,
318
+ )
319
+ x = concat_complex(x, res, dim=1)
320
+
321
+ # match spec dimension
322
+ x = self.out_conv(x)
323
+ if origin_mag.size(2) != x.size(2):
324
+ x = F.interpolate(
325
+ x, size=[origin_mag.size(2)], mode="linear", align_corners=False
326
+ )
327
+
328
+ # refine
329
+ for idx, refine_module in enumerate(self.refine_conv):
330
+ x = refine_module(x)
331
+ mag, phase = x.chunk(2, 1)
332
+ mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
333
+ if idx < len(self.refine_conv) - 1:
334
+ x = torch.cat([mag, phase], dim=1)
335
+
336
+ # clamp phase
337
+ phase = phase.clamp(-np.pi, np.pi)
338
+
339
+ out = self.exp_istft(mag, phase)
340
+ out = self.adjust_diff(out, wav)
341
+
342
+ if self.add_spec_results:
343
+ out = (out, mag, phase)
344
+
345
+ return out
346
+
347
+
348
+ class _ComplexConvNd(nn.Module):
349
+ """
350
+ Implement Complex Convolution
351
+ A: real weight
352
+ B: img weight
353
+ """
354
+
355
+ def __init__(
356
+ self,
357
+ in_channels,
358
+ out_channels,
359
+ kernel_size,
360
+ stride,
361
+ padding,
362
+ dilation,
363
+ transposed,
364
+ output_padding,
365
+ ):
366
+ super().__init__()
367
+ self.in_channels = in_channels
368
+ self.out_channels = out_channels
369
+ self.kernel_size = kernel_size
370
+ self.stride = stride
371
+ self.padding = padding
372
+ self.dilation = dilation
373
+ self.output_padding = output_padding
374
+ self.transposed = transposed
375
+
376
+ self.A = self.make_weight(in_channels, out_channels, kernel_size)
377
+ self.B = self.make_weight(in_channels, out_channels, kernel_size)
378
+
379
+ self.reset_parameters()
380
+
381
+ def make_weight(self, in_ch, out_ch, kernel_size):
382
+ if self.transposed:
383
+ tensor = nn.Parameter(torch.Tensor(in_ch, out_ch // 2, *kernel_size))
384
+ else:
385
+ tensor = nn.Parameter(torch.Tensor(out_ch, in_ch // 2, *kernel_size))
386
+ return tensor
387
+
388
+ def reset_parameters(self):
389
+ # init real weight
390
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.A)
391
+
392
+ # init A
393
+ gain = calculate_gain("leaky_relu", 0)
394
+ std = gain / np.sqrt(fan_in)
395
+ bound = np.sqrt(3.0) * std
396
+
397
+ with torch.no_grad():
398
+ # TODO: find more stable initial values
399
+ self.A.uniform_(-bound * (1 / (np.pi**2)), bound * (1 / (np.pi**2)))
400
+ #
401
+ # B is initialized by pi
402
+ # -pi and pi is too big, so it is powed by -1
403
+ self.B.uniform_(-1 / np.pi, 1 / np.pi)
404
+
405
+
406
+ class ComplexConv1d(_ComplexConvNd):
407
+ """
408
+ Complex Convolution 1d
409
+ """
410
+
411
+ def __init__(
412
+ self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
413
+ ):
414
+ kernel_size = single(kernel_size)
415
+ stride = single(stride)
416
+ # edit padding
417
+ padding = padding
418
+ dilation = single(dilation)
419
+ super(ComplexConv1d, self).__init__(
420
+ in_channels,
421
+ out_channels,
422
+ kernel_size,
423
+ stride,
424
+ padding,
425
+ dilation,
426
+ False,
427
+ single(0),
428
+ )
429
+
430
+ def forward(self, x):
431
+ """
432
+ Implemented complex convolution using combining 'grouped convolution' and
433
+ 'real / img weight'
434
+ :param x: data (N, C, T) C is concatenated with C/2 real channels and C/2 idea channels
435
+ :return: complex conved result
436
+ """
437
+ # adopt reflect padding
438
+ if self.padding:
439
+ x = F.pad(x, (self.padding, self.padding), "reflect")
440
+
441
+ # forward real
442
+ real_part = F.conv1d(
443
+ x,
444
+ self.A,
445
+ None,
446
+ stride=self.stride,
447
+ padding=0,
448
+ dilation=self.dilation,
449
+ groups=2,
450
+ )
451
+
452
+ # forward idea
453
+ spl = self.in_channels // 2
454
+ weight_B = torch.cat([self.B[:spl].data * (-1), self.B[spl:].data])
455
+ idea_part = F.conv1d(
456
+ x,
457
+ weight_B,
458
+ None,
459
+ stride=self.stride,
460
+ padding=0,
461
+ dilation=self.dilation,
462
+ groups=2,
463
+ )
464
+
465
+ return real_part + idea_part
466
+
467
+
468
+ class ComplexTransposedConv1d(_ComplexConvNd):
469
+ """
470
+ Complex Transposed Convolution 1d
471
+ """
472
+
473
+ def __init__(
474
+ self,
475
+ in_channels,
476
+ out_channels,
477
+ kernel_size,
478
+ stride=1,
479
+ padding=0,
480
+ output_padding=0,
481
+ dilation=1,
482
+ ):
483
+ kernel_size = single(kernel_size)
484
+ stride = single(stride)
485
+ padding = padding
486
+ dilation = single(dilation)
487
+ super().__init__(
488
+ in_channels,
489
+ out_channels,
490
+ kernel_size,
491
+ stride,
492
+ padding,
493
+ dilation,
494
+ True,
495
+ output_padding,
496
+ )
497
+
498
+ def forward(self, x, output_size=None):
499
+ """
500
+ Implemented complex transposed convolution using combining 'grouped convolution'
501
+ and 'real / img weight'
502
+ :param x: data (N, C, T) C is concatenated with C/2 real channels and C/2 idea channels
503
+ :return: complex transposed convolution result
504
+ """
505
+ # forward real
506
+ if self.padding:
507
+ x = F.pad(x, (self.padding, self.padding), "reflect")
508
+
509
+ real_part = F.conv_transpose1d(
510
+ x,
511
+ self.A,
512
+ None,
513
+ stride=self.stride,
514
+ padding=0,
515
+ dilation=self.dilation,
516
+ groups=2,
517
+ )
518
+
519
+ # forward idea
520
+ spl = self.out_channels // 2
521
+ weight_B = torch.cat([self.B[:spl] * (-1), self.B[spl:]])
522
+ idea_part = F.conv_transpose1d(
523
+ x,
524
+ weight_B,
525
+ None,
526
+ stride=self.stride,
527
+ padding=0,
528
+ dilation=self.dilation,
529
+ groups=2,
530
+ )
531
+
532
+ if self.output_padding:
533
+ real_part = F.pad(
534
+ real_part, (self.output_padding, self.output_padding), "reflect"
535
+ )
536
+ idea_part = F.pad(
537
+ idea_part, (self.output_padding, self.output_padding), "reflect"
538
+ )
539
+
540
+ return real_part + idea_part
541
+
542
+
543
+ class ComplexActLayer(nn.Module):
544
+ """
545
+ Activation differently 'real' part and 'img' part
546
+ In implemented DCUnet on this repository, Real part is activated to log space.
547
+ And Phase(img) part, it is distributed in [-pi, pi]...
548
+ """
549
+
550
+ def forward(self, x):
551
+ real, img = x.chunk(2, 1)
552
+ return torch.cat([F.leaky_relu_(real), torch.tanh(img) * np.pi], dim=1)
553
+
554
+
555
+ class STFT(nn.Module):
556
+ """
557
+ Re-construct stft for calculating backward operation
558
+ refer on : https://github.com/pseeth/torch-stft/blob/master/torch_stft/stft.py
559
+ """
560
+
561
+ def __init__(
562
+ self,
563
+ filter_length: int = 1024,
564
+ hop_length: int = 512,
565
+ win_length: int = None,
566
+ window: str = "hann",
567
+ ):
568
+ super().__init__()
569
+ self.filter_length = filter_length
570
+ self.hop_length = hop_length
571
+ self.win_length = win_length if win_length else filter_length
572
+ self.window = window
573
+ self.pad_amount = self.filter_length // 2
574
+
575
+ # make fft window
576
+ assert filter_length >= self.win_length
577
+ # get window and zero center pad it to filter_length
578
+ fft_window = get_window(window, self.win_length, fftbins=True)
579
+ fft_window = pad_center(fft_window, filter_length)
580
+ fft_window = torch.from_numpy(fft_window).float()
581
+
582
+ # calculate fourer_basis
583
+ cut_off = int((self.filter_length / 2 + 1))
584
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
585
+ fourier_basis = np.vstack(
586
+ [np.real(fourier_basis[:cut_off, :]), np.imag(fourier_basis[:cut_off, :])]
587
+ )
588
+
589
+ # make forward & inverse basis
590
+ self.register_buffer("square_window", fft_window**2)
591
+
592
+ forward_basis = torch.FloatTensor(fourier_basis[:, np.newaxis, :]) * fft_window
593
+ inverse_basis = (
594
+ torch.FloatTensor(
595
+ np.linalg.pinv(self.filter_length / self.hop_length * fourier_basis).T[
596
+ :, np.newaxis, :
597
+ ]
598
+ )
599
+ * fft_window
600
+ )
601
+ # torch.pinverse has a bug, so at this time, it is separated into two parts..
602
+ self.register_buffer("forward_basis", forward_basis)
603
+ self.register_buffer("inverse_basis", inverse_basis)
604
+
605
+ def transform(self, wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
606
+ # reflect padding
607
+ wav = wav.unsqueeze(1).unsqueeze(1)
608
+ wav = F.pad(
609
+ wav, (self.pad_amount, self.pad_amount, 0, 0), mode="reflect"
610
+ ).squeeze(1)
611
+
612
+ # conv
613
+ forward_trans = F.conv1d(
614
+ wav, self.forward_basis, stride=self.hop_length, padding=0
615
+ )
616
+ real_part, imag_part = forward_trans.chunk(2, 1)
617
+
618
+ return torch.sqrt(real_part**2 + imag_part**2), torch.atan2(
619
+ imag_part.data, real_part.data
620
+ )
621
+
622
+ def inverse(
623
+ self, magnitude: torch.Tensor, phase: torch.Tensor, eps: float = 1e-9
624
+ ) -> torch.Tensor:
625
+ comp = torch.cat(
626
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
627
+ )
628
+ inverse_transform = F.conv_transpose1d(
629
+ comp, self.inverse_basis, stride=self.hop_length, padding=0
630
+ )
631
+
632
+ # remove window effect
633
+ n_frames = comp.size(-1)
634
+ inverse_size = inverse_transform.size(-1)
635
+
636
+ window_filter = torch.ones(1, 1, n_frames).type_as(inverse_transform)
637
+
638
+ weight = self.square_window[: self.filter_length].unsqueeze(0).unsqueeze(0)
639
+ window_filter = F.conv_transpose1d(
640
+ window_filter, weight, stride=self.hop_length, padding=0
641
+ )
642
+ window_filter = window_filter.squeeze()[:inverse_size] + eps
643
+
644
+ inverse_transform /= window_filter
645
+
646
+ # scale by hop ratio
647
+ inverse_transform *= self.filter_length / self.hop_length
648
+
649
+ return inverse_transform[..., self.pad_amount : -self.pad_amount].squeeze(1)
remfx/dptnet.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.modules.container import ModuleList
5
+ from torch.nn.modules.activation import MultiheadAttention
6
+ from torch.nn.modules.dropout import Dropout
7
+ from torch.nn.modules.linear import Linear
8
+ from torch.nn.modules.rnn import LSTM
9
+ from torch.nn.modules.normalization import LayerNorm
10
+ from torch.autograd import Variable
11
+ import copy
12
+ import math
13
+
14
+
15
+ # adapted from https://github.com/ujscjj/DPTNet
16
+
17
+
18
+ class DPTNet_base(nn.Module):
19
+ def __init__(
20
+ self,
21
+ enc_dim,
22
+ feature_dim,
23
+ hidden_dim,
24
+ layer,
25
+ segment_size=250,
26
+ nspk=2,
27
+ win_len=2,
28
+ ):
29
+ super().__init__()
30
+ # parameters
31
+ self.window = win_len
32
+ self.stride = self.window // 2
33
+
34
+ self.enc_dim = enc_dim
35
+ self.feature_dim = feature_dim
36
+ self.hidden_dim = hidden_dim
37
+ self.segment_size = segment_size
38
+
39
+ self.layer = layer
40
+ self.num_spk = nspk
41
+ self.eps = 1e-8
42
+
43
+ self.dpt_encoder = DPTEncoder(
44
+ n_filters=enc_dim,
45
+ window_size=win_len,
46
+ )
47
+ self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=1e-8)
48
+ self.dpt_separation = DPTSeparation(
49
+ self.enc_dim,
50
+ self.feature_dim,
51
+ self.hidden_dim,
52
+ self.num_spk,
53
+ self.layer,
54
+ self.segment_size,
55
+ )
56
+
57
+ self.mask_conv1x1 = nn.Conv1d(self.feature_dim, self.enc_dim, 1, bias=False)
58
+ self.decoder = DPTDecoder(n_filters=enc_dim, window_size=win_len)
59
+
60
+ def forward(self, batch):
61
+ """
62
+ mix: shape (batch, T)
63
+ """
64
+ mix, target = batch
65
+ batch_size = mix.shape[0]
66
+ mix = self.dpt_encoder(mix) # (B, E, L)
67
+
68
+ score_ = self.enc_LN(mix) # B, E, L
69
+ score_ = self.dpt_separation(score_) # B, nspk, T, N
70
+ score_ = (
71
+ score_.view(batch_size * self.num_spk, -1, self.feature_dim)
72
+ .transpose(1, 2)
73
+ .contiguous()
74
+ ) # B*nspk, N, T
75
+ score = self.mask_conv1x1(score_) # [B*nspk, N, L] -> [B*nspk, E, L]
76
+ score = score.view(
77
+ batch_size, self.num_spk, self.enc_dim, -1
78
+ ) # [B*nspk, E, L] -> [B, nspk, E, L]
79
+ est_mask = F.relu(score)
80
+
81
+ est_source = self.decoder(
82
+ mix, est_mask
83
+ ) # [B, E, L] + [B, nspk, E, L]--> [B, nspk, T]
84
+
85
+ return est_source
86
+
87
+
88
+ class DPTEncoder(nn.Module):
89
+ def __init__(self, n_filters: int = 64, window_size: int = 2):
90
+ super().__init__()
91
+ self.conv = nn.Conv1d(
92
+ 1, n_filters, kernel_size=window_size, stride=window_size // 2, bias=False
93
+ )
94
+
95
+ def forward(self, x):
96
+ x = x.unsqueeze(1)
97
+ x = F.relu(self.conv(x))
98
+ return x
99
+
100
+
101
+ class TransformerEncoderLayer(torch.nn.Module):
102
+ def __init__(
103
+ self, d_model, nhead, hidden_size, dim_feedforward, dropout, activation="relu"
104
+ ):
105
+ super(TransformerEncoderLayer, self).__init__()
106
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
107
+
108
+ # Implementation of improved part
109
+ self.lstm = LSTM(d_model, hidden_size, 1, bidirectional=True)
110
+ self.dropout = Dropout(dropout)
111
+ self.linear = Linear(hidden_size * 2, d_model)
112
+
113
+ self.norm1 = LayerNorm(d_model)
114
+ self.norm2 = LayerNorm(d_model)
115
+ self.dropout1 = Dropout(dropout)
116
+ self.dropout2 = Dropout(dropout)
117
+
118
+ self.activation = _get_activation_fn(activation)
119
+
120
+ def __setstate__(self, state):
121
+ if "activation" not in state:
122
+ state["activation"] = F.relu
123
+ super(TransformerEncoderLayer, self).__setstate__(state)
124
+
125
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
126
+ r"""Pass the input through the encoder layer.
127
+ Args:
128
+ src: the sequnce to the encoder layer (required).
129
+ src_mask: the mask for the src sequence (optional).
130
+ src_key_padding_mask: the mask for the src keys per batch (optional).
131
+ Shape:
132
+ see the docs in Transformer class.
133
+ """
134
+ src2 = self.self_attn(
135
+ src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
136
+ )[0]
137
+ src = src + self.dropout1(src2)
138
+ src = self.norm1(src)
139
+ src2 = self.linear(self.dropout(self.activation(self.lstm(src)[0])))
140
+ src = src + self.dropout2(src2)
141
+ src = self.norm2(src)
142
+ return src
143
+
144
+
145
+ def _get_clones(module, N):
146
+ return ModuleList([copy.deepcopy(module) for i in range(N)])
147
+
148
+
149
+ def _get_activation_fn(activation):
150
+ if activation == "relu":
151
+ return F.relu
152
+ elif activation == "gelu":
153
+ return F.gelu
154
+
155
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
156
+
157
+
158
+ class SingleTransformer(nn.Module):
159
+ """
160
+ Container module for a single Transformer layer.
161
+ args: input_size: int, dimension of the input feature.
162
+ The input should have shape (batch, seq_len, input_size).
163
+ """
164
+
165
+ def __init__(self, input_size, hidden_size, dropout):
166
+ super(SingleTransformer, self).__init__()
167
+ self.transformer = TransformerEncoderLayer(
168
+ d_model=input_size,
169
+ nhead=4,
170
+ hidden_size=hidden_size,
171
+ dim_feedforward=hidden_size * 2,
172
+ dropout=dropout,
173
+ )
174
+
175
+ def forward(self, input):
176
+ # input shape: batch, seq, dim
177
+ output = input
178
+ transformer_output = (
179
+ self.transformer(output.permute(1, 0, 2).contiguous())
180
+ .permute(1, 0, 2)
181
+ .contiguous()
182
+ )
183
+ return transformer_output
184
+
185
+
186
+ # dual-path transformer
187
+ class DPT(nn.Module):
188
+ """
189
+ Deep dual-path transformer.
190
+ args:
191
+ input_size: int, dimension of the input feature. The input should have shape
192
+ (batch, seq_len, input_size).
193
+ hidden_size: int, dimension of the hidden state.
194
+ output_size: int, dimension of the output size.
195
+ num_layers: int, number of stacked Transformer layers. Default is 1.
196
+ dropout: float, dropout ratio. Default is 0.
197
+ """
198
+
199
+ def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0):
200
+ super(DPT, self).__init__()
201
+
202
+ self.input_size = input_size
203
+ self.output_size = output_size
204
+ self.hidden_size = hidden_size
205
+
206
+ # dual-path transformer
207
+ self.row_transformer = nn.ModuleList([])
208
+ self.col_transformer = nn.ModuleList([])
209
+ for i in range(num_layers):
210
+ self.row_transformer.append(
211
+ SingleTransformer(input_size, hidden_size, dropout)
212
+ )
213
+ self.col_transformer.append(
214
+ SingleTransformer(input_size, hidden_size, dropout)
215
+ )
216
+
217
+ # output layer
218
+ self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))
219
+
220
+ def forward(self, input):
221
+ # input shape: batch, N, dim1, dim2
222
+ # apply transformer on dim1 first and then dim2
223
+ # output shape: B, output_size, dim1, dim2
224
+ # input = input.to(device)
225
+ batch_size, _, dim1, dim2 = input.shape
226
+ output = input
227
+ for i in range(len(self.row_transformer)):
228
+ row_input = (
229
+ output.permute(0, 3, 2, 1)
230
+ .contiguous()
231
+ .view(batch_size * dim2, dim1, -1)
232
+ ) # B*dim2, dim1, N
233
+ row_output = self.row_transformer[i](row_input) # B*dim2, dim1, H
234
+ row_output = (
235
+ row_output.view(batch_size, dim2, dim1, -1)
236
+ .permute(0, 3, 2, 1)
237
+ .contiguous()
238
+ ) # B, N, dim1, dim2
239
+ output = row_output
240
+
241
+ col_input = (
242
+ output.permute(0, 2, 3, 1)
243
+ .contiguous()
244
+ .view(batch_size * dim1, dim2, -1)
245
+ ) # B*dim1, dim2, N
246
+ col_output = self.col_transformer[i](col_input) # B*dim1, dim2, H
247
+ col_output = (
248
+ col_output.view(batch_size, dim1, dim2, -1)
249
+ .permute(0, 3, 1, 2)
250
+ .contiguous()
251
+ ) # B, N, dim1, dim2
252
+ output = col_output
253
+
254
+ output = self.output(output) # B, output_size, dim1, dim2
255
+
256
+ return output
257
+
258
+
259
+ # base module for deep DPT
260
+ class DPT_base(nn.Module):
261
+ def __init__(
262
+ self, input_dim, feature_dim, hidden_dim, num_spk=2, layer=6, segment_size=250
263
+ ):
264
+ super(DPT_base, self).__init__()
265
+
266
+ self.input_dim = input_dim
267
+ self.feature_dim = feature_dim
268
+ self.hidden_dim = hidden_dim
269
+
270
+ self.layer = layer
271
+ self.segment_size = segment_size
272
+ self.num_spk = num_spk
273
+
274
+ self.eps = 1e-8
275
+
276
+ # bottleneck
277
+ self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, bias=False)
278
+
279
+ # DPT model
280
+ self.DPT = DPT(
281
+ self.feature_dim,
282
+ self.hidden_dim,
283
+ self.feature_dim * self.num_spk,
284
+ num_layers=layer,
285
+ )
286
+
287
+ def pad_segment(self, input, segment_size):
288
+ # input is the features: (B, N, T)
289
+ batch_size, dim, seq_len = input.shape
290
+ segment_stride = segment_size // 2
291
+
292
+ rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
293
+ if rest > 0:
294
+ pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
295
+ input = torch.cat([input, pad], 2)
296
+
297
+ pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(
298
+ input.type()
299
+ )
300
+ input = torch.cat([pad_aux, input, pad_aux], 2)
301
+
302
+ return input, rest
303
+
304
+ def split_feature(self, input, segment_size):
305
+ # split the feature into chunks of segment size
306
+ # input is the features: (B, N, T)
307
+
308
+ input, rest = self.pad_segment(input, segment_size)
309
+ batch_size, dim, seq_len = input.shape
310
+ segment_stride = segment_size // 2
311
+
312
+ segments1 = (
313
+ input[:, :, :-segment_stride]
314
+ .contiguous()
315
+ .view(batch_size, dim, -1, segment_size)
316
+ )
317
+ segments2 = (
318
+ input[:, :, segment_stride:]
319
+ .contiguous()
320
+ .view(batch_size, dim, -1, segment_size)
321
+ )
322
+ segments = (
323
+ torch.cat([segments1, segments2], 3)
324
+ .view(batch_size, dim, -1, segment_size)
325
+ .transpose(2, 3)
326
+ )
327
+
328
+ return segments.contiguous(), rest
329
+
330
+ def merge_feature(self, input, rest):
331
+ # merge the splitted features into full utterance
332
+ # input is the features: (B, N, L, K)
333
+
334
+ batch_size, dim, segment_size, _ = input.shape
335
+ segment_stride = segment_size // 2
336
+ input = (
337
+ input.transpose(2, 3)
338
+ .contiguous()
339
+ .view(batch_size, dim, -1, segment_size * 2)
340
+ ) # B, N, K, L
341
+
342
+ input1 = (
343
+ input[:, :, :, :segment_size]
344
+ .contiguous()
345
+ .view(batch_size, dim, -1)[:, :, segment_stride:]
346
+ )
347
+ input2 = (
348
+ input[:, :, :, segment_size:]
349
+ .contiguous()
350
+ .view(batch_size, dim, -1)[:, :, :-segment_stride]
351
+ )
352
+
353
+ output = input1 + input2
354
+ if rest > 0:
355
+ output = output[:, :, :-rest]
356
+
357
+ return output.contiguous() # B, N, T
358
+
359
+ def forward(self, input):
360
+ pass
361
+
362
+
363
+ class DPTSeparation(DPT_base):
364
+ def __init__(self, *args, **kwargs):
365
+ super(DPTSeparation, self).__init__(*args, **kwargs)
366
+
367
+ # gated output layer
368
+ self.output = nn.Sequential(
369
+ nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Tanh()
370
+ )
371
+ self.output_gate = nn.Sequential(
372
+ nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Sigmoid()
373
+ )
374
+
375
+ def forward(self, input):
376
+ # input = input.to(device)
377
+ # input: (B, E, T)
378
+ batch_size, E, seq_length = input.shape
379
+
380
+ enc_feature = self.BN(input) # (B, E, L)-->(B, N, L)
381
+ # split the encoder output into overlapped, longer segments
382
+ enc_segments, enc_rest = self.split_feature(
383
+ enc_feature, self.segment_size
384
+ ) # B, N, L, K: L is the segment_size
385
+ # print('enc_segments.shape {}'.format(enc_segments.shape))
386
+ # pass to DPT
387
+ output = self.DPT(enc_segments).view(
388
+ batch_size * self.num_spk, self.feature_dim, self.segment_size, -1
389
+ ) # B*nspk, N, L, K
390
+
391
+ # overlap-and-add of the outputs
392
+ output = self.merge_feature(output, enc_rest) # B*nspk, N, T
393
+
394
+ # gated output layer for filter generation
395
+ bf_filter = self.output(output) * self.output_gate(output) # B*nspk, K, T
396
+ bf_filter = (
397
+ bf_filter.transpose(1, 2)
398
+ .contiguous()
399
+ .view(batch_size, self.num_spk, -1, self.feature_dim)
400
+ ) # B, nspk, T, N
401
+
402
+ return bf_filter
403
+
404
+
405
+ class DPTDecoder(nn.Module):
406
+ def __init__(self, n_filters: int = 64, window_size: int = 2):
407
+ super().__init__()
408
+ self.W = window_size
409
+ self.basis_signals = nn.Linear(n_filters, window_size, bias=False)
410
+
411
+ def forward(self, mixture, mask):
412
+ """
413
+ mixture: (batch, n_filters, L)
414
+ mask: (batch, sources, n_filters, L)
415
+ """
416
+ source_w = torch.unsqueeze(mixture, 1) * mask # [B, C, E, L]
417
+ source_w = torch.transpose(source_w, 2, 3) # [B, C, L, E]
418
+ # S = DV
419
+ est_source = self.basis_signals(source_w) # [B, C, L, W]
420
+ est_source = overlap_and_add(est_source, self.W // 2) # B x C x T
421
+ return est_source
422
+
423
+
424
+ def overlap_and_add(signal, frame_step):
425
+ """Reconstructs a signal from a framed representation.
426
+ Adds potentially overlapping frames of a signal with shape
427
+ `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
428
+ The resulting tensor has shape `[..., output_size]` where
429
+ output_size = (frames - 1) * frame_step + frame_length
430
+ Args:
431
+ signal: A [..., frames, frame_length] Tensor.
432
+ All dimensions may be unknown, and rank must be at least 2.
433
+ frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
434
+ Returns:
435
+ A Tensor with shape [..., output_size] containing the overlap-added frames of signal's
436
+ inner-most two dimensions.
437
+ output_size = (frames - 1) * frame_step + frame_length
438
+ Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
439
+ """
440
+ outer_dimensions = signal.size()[:-2]
441
+ frames, frame_length = signal.size()[-2:]
442
+
443
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
444
+ subframe_step = frame_step // subframe_length
445
+ subframes_per_frame = frame_length // subframe_length
446
+ output_size = frame_step * (frames - 1) + frame_length
447
+ output_subframes = output_size // subframe_length
448
+
449
+ subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)
450
+
451
+ frame = torch.arange(0, output_subframes).unfold(
452
+ 0, subframes_per_frame, subframe_step
453
+ )
454
+ frame = signal.new_tensor(frame).long() # signal may in GPU or CPU
455
+ frame = frame.contiguous().view(-1)
456
+
457
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
458
+ result.index_add_(-2, frame, subframe_signal)
459
+ result = result.view(*outer_dimensions, -1)
460
+ return result
remfx/models.py CHANGED
@@ -1,10 +1,6 @@
1
- import wandb
2
  import torch
3
- import torchaudio
4
  import torchmetrics
5
  import pytorch_lightning as pl
6
- import torch.nn.functional as F
7
-
8
  from torch import Tensor, nn
9
  from einops import rearrange
10
  from torchaudio.models import HDemucs
@@ -13,10 +9,12 @@ from auraloss.time import SISDRLoss
13
  from auraloss.freq import MultiResolutionSTFTLoss
14
  from umx.openunmix.model import OpenUnmix, Separator
15
 
16
- from remfx.utils import FADLoss
 
 
17
 
18
 
19
- class RemFXModel(pl.LightningModule):
20
  def __init__(
21
  self,
22
  lr: float,
@@ -35,7 +33,7 @@ class RemFXModel(pl.LightningModule):
35
  self.lr_weight_decay = lr_weight_decay
36
  self.sample_rate = sample_rate
37
  self.model = network
38
- self.metrics = torch.nn.ModuleDict(
39
  {
40
  "SISDR": SISDRLoss(),
41
  "STFT": MultiResolutionSTFTLoss(),
@@ -94,7 +92,8 @@ class RemFXModel(pl.LightningModule):
94
  return loss
95
 
96
  def common_step(self, batch, batch_idx, mode: str = "train"):
97
- x, y, _, _ = batch
 
98
  loss, output = self.model((x, y))
99
  self.log(f"{mode}_loss", loss)
100
  # Metric logging
@@ -201,7 +200,7 @@ class RemFXModel(pl.LightningModule):
201
  )
202
 
203
 
204
- class OpenUnmixModel(torch.nn.Module):
205
  def __init__(
206
  self,
207
  n_fft: int = 2048,
@@ -234,7 +233,7 @@ class OpenUnmixModel(torch.nn.Module):
234
  self.mrstftloss = MultiResolutionSTFTLoss(
235
  n_bins=self.num_bins, sample_rate=self.sample_rate
236
  )
237
- self.l1loss = torch.nn.L1Loss()
238
 
239
  def forward(self, batch):
240
  x, target = batch
@@ -249,7 +248,7 @@ class OpenUnmixModel(torch.nn.Module):
249
  return self.separator(x).squeeze(1)
250
 
251
 
252
- class DemucsModel(torch.nn.Module):
253
  def __init__(self, sample_rate, **kwargs) -> None:
254
  super().__init__()
255
  self.model = HDemucs(**kwargs)
@@ -257,7 +256,7 @@ class DemucsModel(torch.nn.Module):
257
  self.mrstftloss = MultiResolutionSTFTLoss(
258
  n_bins=self.num_bins, sample_rate=sample_rate
259
  )
260
- self.l1loss = torch.nn.L1Loss()
261
 
262
  def forward(self, batch):
263
  x, target = batch
@@ -284,201 +283,42 @@ class DiffusionGenerationModel(nn.Module):
284
  return self.model.sample(noise, num_steps=num_steps)
285
 
286
 
287
- def log_wandb_audio_batch(
288
- logger: pl.loggers.WandbLogger,
289
- id: str,
290
- samples: Tensor,
291
- sampling_rate: int,
292
- caption: str = "",
293
- max_items: int = 10,
294
- ):
295
- num_items = samples.shape[0]
296
- samples = rearrange(samples, "b c t -> b t c")
297
- for idx in range(num_items):
298
- if idx >= max_items:
299
- break
300
- logger.experiment.log(
301
- {
302
- f"{id}_{idx}": wandb.Audio(
303
- samples[idx].cpu().numpy(),
304
- caption=caption,
305
- sample_rate=sampling_rate,
306
- )
307
- }
308
- )
309
-
310
-
311
- def spectrogram(
312
- x: torch.Tensor,
313
- window: torch.Tensor,
314
- n_fft: int,
315
- hop_length: int,
316
- alpha: float,
317
- ) -> torch.Tensor:
318
- bs, chs, samp = x.size()
319
- x = x.view(bs * chs, -1) # move channels onto batch dim
320
-
321
- X = torch.stft(
322
- x,
323
- n_fft=n_fft,
324
- hop_length=hop_length,
325
- window=window,
326
- return_complex=True,
327
- )
328
-
329
- # move channels back
330
- X = X.view(bs, chs, X.shape[-2], X.shape[-1])
331
-
332
- return torch.pow(X.abs() + 1e-8, alpha)
333
-
334
-
335
- # adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
336
-
337
-
338
- def init_layer(layer):
339
- """Initialize a Linear or Convolutional layer."""
340
- nn.init.xavier_uniform_(layer.weight)
341
-
342
- if hasattr(layer, "bias"):
343
- if layer.bias is not None:
344
- layer.bias.data.fill_(0.0)
345
-
346
-
347
- def init_bn(bn):
348
- """Initialize a Batchnorm layer."""
349
- bn.bias.data.fill_(0.0)
350
- bn.weight.data.fill_(1.0)
351
-
352
-
353
- class ConvBlock(nn.Module):
354
- def __init__(self, in_channels, out_channels):
355
- super(ConvBlock, self).__init__()
356
-
357
- self.conv1 = nn.Conv2d(
358
- in_channels=in_channels,
359
- out_channels=out_channels,
360
- kernel_size=(3, 3),
361
- stride=(1, 1),
362
- padding=(1, 1),
363
- bias=False,
364
- )
365
-
366
- self.conv2 = nn.Conv2d(
367
- in_channels=out_channels,
368
- out_channels=out_channels,
369
- kernel_size=(3, 3),
370
- stride=(1, 1),
371
- padding=(1, 1),
372
- bias=False,
373
  )
 
374
 
375
- self.bn1 = nn.BatchNorm2d(out_channels)
376
- self.bn2 = nn.BatchNorm2d(out_channels)
377
-
378
- self.init_weight()
379
-
380
- def init_weight(self):
381
- init_layer(self.conv1)
382
- init_layer(self.conv2)
383
- init_bn(self.bn1)
384
- init_bn(self.bn2)
385
-
386
- def forward(self, input, pool_size=(2, 2), pool_type="avg"):
387
- x = input
388
- x = F.relu_(self.bn1(self.conv1(x)))
389
- x = F.relu_(self.bn2(self.conv2(x)))
390
- if pool_type == "max":
391
- x = F.max_pool2d(x, kernel_size=pool_size)
392
- elif pool_type == "avg":
393
- x = F.avg_pool2d(x, kernel_size=pool_size)
394
- elif pool_type == "avg+max":
395
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
396
- x2 = F.max_pool2d(x, kernel_size=pool_size)
397
- x = x1 + x2
398
- else:
399
- raise Exception("Incorrect argument!")
400
 
401
- return x
 
402
 
403
 
404
- class Cnn14(nn.Module):
405
- def __init__(
406
- self,
407
- num_classes: int,
408
- sample_rate: float,
409
- n_fft: int = 2048,
410
- hop_length: int = 512,
411
- n_mels: int = 128,
412
- ):
413
  super().__init__()
414
- self.num_classes = num_classes
415
- self.n_fft = n_fft
416
- self.hop_length = hop_length
417
-
418
- window = torch.hann_window(n_fft)
419
- self.register_buffer("window", window)
420
-
421
- self.melspec = torchaudio.transforms.MelSpectrogram(
422
- sample_rate,
423
- n_fft,
424
- hop_length=hop_length,
425
- n_mels=n_mels,
426
  )
 
427
 
428
- self.bn0 = nn.BatchNorm2d(n_mels)
429
-
430
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
431
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
432
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
433
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
434
- self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
435
- self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
436
-
437
- self.fc1 = nn.Linear(2048, 2048, bias=True)
438
- self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
439
-
440
- self.init_weight()
441
-
442
- def init_weight(self):
443
- init_bn(self.bn0)
444
- init_layer(self.fc1)
445
- init_layer(self.fc_audioset)
446
 
447
- def forward(self, x: torch.Tensor):
448
- """
449
- Input: (batch_size, data_length)"""
450
-
451
- x = self.melspec(x)
452
- x = x.permute(0, 2, 1, 3)
453
- x = self.bn0(x)
454
- x = x.permute(0, 2, 1, 3)
455
-
456
- if self.training:
457
- pass
458
- # x = self.spec_augmenter(x)
459
-
460
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
461
- x = F.dropout(x, p=0.2, training=self.training)
462
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
463
- x = F.dropout(x, p=0.2, training=self.training)
464
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
465
- x = F.dropout(x, p=0.2, training=self.training)
466
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
467
- x = F.dropout(x, p=0.2, training=self.training)
468
- x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
469
- x = F.dropout(x, p=0.2, training=self.training)
470
- x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
471
- x = F.dropout(x, p=0.2, training=self.training)
472
- x = torch.mean(x, dim=3)
473
-
474
- (x1, _) = torch.max(x, dim=2)
475
- x2 = torch.mean(x, dim=2)
476
- x = x1 + x2
477
- x = F.dropout(x, p=0.5, training=self.training)
478
- x = F.relu_(self.fc1(x))
479
- clipwise_output = self.fc_audioset(x)
480
-
481
- return clipwise_output
482
 
483
 
484
  class FXClassifier(pl.LightningModule):
@@ -501,7 +341,7 @@ class FXClassifier(pl.LightningModule):
501
  def common_step(self, batch, batch_idx, mode: str = "train"):
502
  x, y, dry_label, wet_label = batch
503
  pred_label = self.network(x)
504
- loss = torch.nn.functional.cross_entropy(pred_label, dry_label)
505
  self.log(
506
  f"{mode}_loss",
507
  loss,
 
 
1
  import torch
 
2
  import torchmetrics
3
  import pytorch_lightning as pl
 
 
4
  from torch import Tensor, nn
5
  from einops import rearrange
6
  from torchaudio.models import HDemucs
 
9
  from auraloss.freq import MultiResolutionSTFTLoss
10
  from umx.openunmix.model import OpenUnmix, Separator
11
 
12
+ from utils import FADLoss, spectrogram, log_wandb_audio_batch
13
+ from dptnet import DPTNet_base
14
+ from dcunet import RefineSpectrogramUnet
15
 
16
 
17
+ class RemFX(pl.LightningModule):
18
  def __init__(
19
  self,
20
  lr: float,
 
33
  self.lr_weight_decay = lr_weight_decay
34
  self.sample_rate = sample_rate
35
  self.model = network
36
+ self.metrics = nn.ModuleDict(
37
  {
38
  "SISDR": SISDRLoss(),
39
  "STFT": MultiResolutionSTFTLoss(),
 
92
  return loss
93
 
94
  def common_step(self, batch, batch_idx, mode: str = "train"):
95
+ x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
96
+
97
  loss, output = self.model((x, y))
98
  self.log(f"{mode}_loss", loss)
99
  # Metric logging
 
200
  )
201
 
202
 
203
+ class OpenUnmixModel(nn.Module):
204
  def __init__(
205
  self,
206
  n_fft: int = 2048,
 
233
  self.mrstftloss = MultiResolutionSTFTLoss(
234
  n_bins=self.num_bins, sample_rate=self.sample_rate
235
  )
236
+ self.l1loss = nn.L1Loss()
237
 
238
  def forward(self, batch):
239
  x, target = batch
 
248
  return self.separator(x).squeeze(1)
249
 
250
 
251
+ class DemucsModel(nn.Module):
252
  def __init__(self, sample_rate, **kwargs) -> None:
253
  super().__init__()
254
  self.model = HDemucs(**kwargs)
 
256
  self.mrstftloss = MultiResolutionSTFTLoss(
257
  n_bins=self.num_bins, sample_rate=sample_rate
258
  )
259
+ self.l1loss = nn.L1Loss()
260
 
261
  def forward(self, batch):
262
  x, target = batch
 
283
  return self.model.sample(noise, num_steps=num_steps)
284
 
285
 
286
+ class DPTNetModel(nn.Module):
287
+ def __init__(self, sample_rate, **kwargs):
288
+ super().__init__()
289
+ self.model = DPTNet_base(**kwargs)
290
+ self.mrstftloss = MultiResolutionSTFTLoss(
291
+ n_bins=self.num_bins, sample_rate=sample_rate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  )
293
+ self.l1loss = nn.L1Loss()
294
 
295
+ def forward(self, batch):
296
+ x, target = batch
297
+ output = self.model(x).squeeze(1)
298
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
299
+ return loss, output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
+ def sample(self, x: Tensor) -> Tensor:
302
+ return self.model.sample(x)
303
 
304
 
305
+ class DCUNetModel(nn.Module):
306
+ def __init__(self, sample_rate, **kwargs):
 
 
 
 
 
 
 
307
  super().__init__()
308
+ self.model = RefineSpectrogramUnet(**kwargs)
309
+ self.mrstftloss = MultiResolutionSTFTLoss(
310
+ n_bins=self.num_bins, sample_rate=sample_rate
 
 
 
 
 
 
 
 
 
311
  )
312
+ self.l1loss = nn.L1Loss()
313
 
314
+ def forward(self, batch):
315
+ x, target = batch
316
+ output = self.model(x).squeeze(1)
317
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
318
+ return loss, output
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
+ def sample(self, x: Tensor) -> Tensor:
321
+ return self.model.sample(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
 
324
  class FXClassifier(pl.LightningModule):
 
341
  def common_step(self, batch, batch_idx, mode: str = "train"):
342
  x, y, dry_label, wet_label = batch
343
  pred_label = self.network(x)
344
+ loss = nn.functional.cross_entropy(pred_label, dry_label)
345
  self.log(
346
  f"{mode}_loss",
347
  loss,
remfx/utils.py CHANGED
@@ -7,6 +7,10 @@ from frechet_audio_distance import FrechetAudioDistance
7
  import numpy as np
8
  import torch
9
  import torchaudio
 
 
 
 
10
 
11
 
12
  def get_logger(name=__name__) -> logging.Logger:
@@ -138,3 +142,91 @@ def create_sequential_chunks(
138
  break
139
  chunks.append(audio[:, start : start + chunk_size])
140
  return chunks, sr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import numpy as np
8
  import torch
9
  import torchaudio
10
+ from torch import Tensor, nn
11
+ import wandb
12
+ from einops import rearrange
13
+ from torch._six import container_abcs
14
 
15
 
16
  def get_logger(name=__name__) -> logging.Logger:
 
142
  break
143
  chunks.append(audio[:, start : start + chunk_size])
144
  return chunks, sr
145
+
146
+
147
+ def log_wandb_audio_batch(
148
+ logger: pl.loggers.WandbLogger,
149
+ id: str,
150
+ samples: Tensor,
151
+ sampling_rate: int,
152
+ caption: str = "",
153
+ max_items: int = 10,
154
+ ):
155
+ num_items = samples.shape[0]
156
+ samples = rearrange(samples, "b c t -> b t c")
157
+ for idx in range(num_items):
158
+ if idx >= max_items:
159
+ break
160
+ logger.experiment.log(
161
+ {
162
+ f"{id}_{idx}": wandb.Audio(
163
+ samples[idx].cpu().numpy(),
164
+ caption=caption,
165
+ sample_rate=sampling_rate,
166
+ )
167
+ }
168
+ )
169
+
170
+
171
+ def spectrogram(
172
+ x: torch.Tensor,
173
+ window: torch.Tensor,
174
+ n_fft: int,
175
+ hop_length: int,
176
+ alpha: float,
177
+ ) -> torch.Tensor:
178
+ bs, chs, samp = x.size()
179
+ x = x.view(bs * chs, -1) # move channels onto batch dim
180
+
181
+ X = torch.stft(
182
+ x,
183
+ n_fft=n_fft,
184
+ hop_length=hop_length,
185
+ window=window,
186
+ return_complex=True,
187
+ )
188
+
189
+ # move channels back
190
+ X = X.view(bs, chs, X.shape[-2], X.shape[-1])
191
+
192
+ return torch.pow(X.abs() + 1e-8, alpha)
193
+
194
+
195
+ def init_layer(layer):
196
+ """Initialize a Linear or Convolutional layer."""
197
+ nn.init.xavier_uniform_(layer.weight)
198
+
199
+ if hasattr(layer, "bias"):
200
+ if layer.bias is not None:
201
+ layer.bias.data.fill_(0.0)
202
+
203
+
204
+ def init_bn(bn):
205
+ """Initialize a Batchnorm layer."""
206
+ bn.bias.data.fill_(0.0)
207
+ bn.weight.data.fill_(1.0)
208
+
209
+
210
+ def _ntuple(n: int):
211
+ def parse(x):
212
+ if isinstance(x, container_abcs.Iterable):
213
+ return x
214
+ return tuple([x] * n)
215
+
216
+ return parse
217
+
218
+
219
+ single = _ntuple(1)
220
+
221
+
222
+ def concat_complex(a: torch.tensor, b: torch.tensor, dim: int = 1) -> torch.tensor:
223
+ """
224
+ Concatenate two complex tensors in same dimension concept
225
+ :param a: complex tensor
226
+ :param b: another complex tensor
227
+ :param dim: target dimension
228
+ :return: concatenated tensor
229
+ """
230
+ a_real, a_img = a.chunk(2, dim)
231
+ b_real, b_img = b.chunk(2, dim)
232
+ return torch.cat([a_real, b_real, a_img, b_img], dim=dim)
scripts/test.py CHANGED
@@ -3,7 +3,6 @@ import hydra
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
5
  from pytorch_lightning.utilities.model_summary import ModelSummary
6
- from remfx.models import RemFXModel
7
  import torch
8
 
9
  log = utils.get_logger(__name__)
 
3
  from omegaconf import DictConfig
4
  import remfx.utils as utils
5
  from pytorch_lightning.utilities.model_summary import ModelSummary
 
6
  import torch
7
 
8
  log = utils.get_logger(__name__)