antonbol committed
Commit af564b2
1 Parent(s): a7804b2
lib/__init__.py ADDED
File without changes
lib/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (156 Bytes)
 
lib/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (7.19 kB)
 
lib/__pycache__/layers.cpython-38.pyc ADDED
Binary file (4.48 kB)
 
lib/__pycache__/nets.cpython-38.pyc ADDED
Binary file (3.8 kB)
 
lib/__pycache__/spec_utils.cpython-38.pyc ADDED
Binary file (5.33 kB)
 
lib/__pycache__/utils.cpython-38.pyc ADDED
Binary file (887 Bytes)
 
lib/dataset.py ADDED
@@ -0,0 +1,257 @@
+ import os
+ import random
+
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from tqdm import tqdm
+
+ try:
+     from lib import spec_utils
+ except ModuleNotFoundError:
+     import spec_utils
+
+
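+ # Training dataset: serves random magnitude-spectrogram crops of cached
+ # (mixture, instrumental) pairs, with optional augmentation and mixup.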
+ class VocalRemoverTrainingSet(torch.utils.data.Dataset):
+
+     def __init__(self, training_set, cropsize, reduction_rate, reduction_weight, mixup_rate, mixup_alpha):
+         self.training_set = training_set
+         self.cropsize = cropsize
+         self.reduction_rate = reduction_rate
+         self.reduction_weight = reduction_weight
+         self.mixup_rate = mixup_rate
+         self.mixup_alpha = mixup_alpha
+
+     def __len__(self):
+         return len(self.training_set)
+
+     def do_crop(self, X_path, y_path):
+         X_mmap = np.load(X_path, mmap_mode='r')
+         y_mmap = np.load(y_path, mmap_mode='r')
+
+         start = np.random.randint(0, X_mmap.shape[2] - self.cropsize)
+         end = start + self.cropsize
+
+         X_crop = np.array(X_mmap[:, :, start:end], copy=True)
+         y_crop = np.array(y_mmap[:, :, start:end], copy=True)
+
+         return X_crop, y_crop
+
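+     # Random augmentations: aggressive vocal reduction, left/right channel
+     # swap, and (rarely) substituting the instrumental for the mixture.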
+     def do_aug(self, X, y):
+         if np.random.uniform() < self.reduction_rate:
+             y = spec_utils.aggressively_remove_vocal(X, y, self.reduction_weight)
+
+         if np.random.uniform() < 0.5:
+             # swap channel
+             X = X[::-1].copy()
+             y = y[::-1].copy()
+
+         if np.random.uniform() < 0.01:
+             # inst
+             X = y.copy()
+
+         # if np.random.uniform() < 0.01:
+         #     # mono
+         #     X[:] = X.mean(axis=0, keepdims=True)
+         #     y[:] = y.mean(axis=0, keepdims=True)
+
+         return X, y
+
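+     # Mixup: blend this example with another randomly drawn one, weighted
+     # by a Beta(mixup_alpha, mixup_alpha) sample.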
+     def do_mixup(self, X, y):
+         idx = np.random.randint(0, len(self))
+         X_path, y_path, coef = self.training_set[idx]
+
+         X_i, y_i = self.do_crop(X_path, y_path)
+         X_i /= coef
+         y_i /= coef
+
+         X_i, y_i = self.do_aug(X_i, y_i)
+
+         lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+         X = lam * X + (1 - lam) * X_i
+         y = lam * y + (1 - lam) * y_i
+
+         return X, y
+
+     def __getitem__(self, idx):
+         X_path, y_path, coef = self.training_set[idx]
+
+         X, y = self.do_crop(X_path, y_path)
+         X /= coef
+         y /= coef
+
+         X, y = self.do_aug(X, y)
+
+         if np.random.uniform() < self.mixup_rate:
+             X, y = self.do_mixup(X, y)
+
+         X_mag = np.abs(X)
+         y_mag = np.abs(y)
+
+         return X_mag, y_mag
+
+
+ class VocalRemoverValidationSet(torch.utils.data.Dataset):
+
+     def __init__(self, patch_list):
+         self.patch_list = patch_list
+
+     def __len__(self):
+         return len(self.patch_list)
+
+     def __getitem__(self, idx):
+         path = self.patch_list[idx]
+         data = np.load(path)
+
+         X, y = data['X'], data['y']
+
+         X_mag = np.abs(X)
+         y_mag = np.abs(y)
+
+         return X_mag, y_mag
+
+
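+ # Pair mixture and instrumental files by sorted filename.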
+ def make_pair(mix_dir, inst_dir):
+     input_exts = ['.wav', '.m4a', '.mp3', '.mp4', '.flac']
+
+     X_list = sorted([
+         os.path.join(mix_dir, fname)
+         for fname in os.listdir(mix_dir)
+         if os.path.splitext(fname)[1] in input_exts
+     ])
+     y_list = sorted([
+         os.path.join(inst_dir, fname)
+         for fname in os.listdir(inst_dir)
+         if os.path.splitext(fname)[1] in input_exts
+     ])
+
+     filelist = list(zip(X_list, y_list))
+
+     return filelist
+
+
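+ # Split into train/validation either by random shuffle ('random') or by
+ # fixed training/ and validation/ subdirectories ('subdirs').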
+ def train_val_split(dataset_dir, split_mode, val_rate, val_filelist):
+     if split_mode == 'random':
+         filelist = make_pair(
+             os.path.join(dataset_dir, 'mixtures'),
+             os.path.join(dataset_dir, 'instruments')
+         )
+
+         random.shuffle(filelist)
+
+         if len(val_filelist) == 0:
+             val_size = int(len(filelist) * val_rate)
+             train_filelist = filelist[:-val_size]
+             val_filelist = filelist[-val_size:]
+         else:
+             train_filelist = [
+                 pair for pair in filelist
+                 if list(pair) not in val_filelist
+             ]
+     elif split_mode == 'subdirs':
+         if len(val_filelist) != 0:
+             raise ValueError('`val_filelist` option is not available with `subdirs` mode')
+
+         train_filelist = make_pair(
+             os.path.join(dataset_dir, 'training/mixtures'),
+             os.path.join(dataset_dir, 'training/instruments')
+         )
+
+         val_filelist = make_pair(
+             os.path.join(dataset_dir, 'validation/mixtures'),
+             os.path.join(dataset_dir, 'validation/instruments')
+         )
+
+     return train_filelist, val_filelist
+
+
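+ # Compute left/right padding so the spectrogram width divides into ROI
+ # windows of size cropsize - 2 * offset (the offset is the edge context
+ # the model discards).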
+ def make_padding(width, cropsize, offset):
+     left = offset
+     roi_size = cropsize - offset * 2
+     if roi_size == 0:
+         roi_size = cropsize
+     right = roi_size - (width % roi_size) + left
+
+     return left, right, roi_size
+
+
+ def make_training_set(filelist, sr, hop_length, n_fft):
+     ret = []
+     for X_path, y_path in tqdm(filelist):
+         X, y, X_cache_path, y_cache_path = spec_utils.cache_or_load(
+             X_path, y_path, sr, hop_length, n_fft
+         )
+         coef = np.max([np.abs(X).max(), np.abs(y).max()])
+         ret.append([X_cache_path, y_cache_path, coef])
+
+     return ret
+
+
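+ # Slice each validation pair into fixed-size patches (consecutive patches
+ # overlap by 2 * offset) and cache them as .npz files.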
+ def make_validation_set(filelist, cropsize, sr, hop_length, n_fft, offset):
+     patch_list = []
+     patch_dir = 'cs{}_sr{}_hl{}_nf{}_of{}'.format(cropsize, sr, hop_length, n_fft, offset)
+     os.makedirs(patch_dir, exist_ok=True)
+
+     for X_path, y_path in tqdm(filelist):
+         basename = os.path.splitext(os.path.basename(X_path))[0]
+
+         X, y, _, _ = spec_utils.cache_or_load(X_path, y_path, sr, hop_length, n_fft)
+         coef = np.max([np.abs(X).max(), np.abs(y).max()])
+         X, y = X / coef, y / coef
+
+         l, r, roi_size = make_padding(X.shape[2], cropsize, offset)
+         X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode='constant')
+         y_pad = np.pad(y, ((0, 0), (0, 0), (l, r)), mode='constant')
+
+         len_dataset = int(np.ceil(X.shape[2] / roi_size))
+         for j in range(len_dataset):
+             outpath = os.path.join(patch_dir, '{}_p{}.npz'.format(basename, j))
+             start = j * roi_size
+             if not os.path.exists(outpath):
+                 np.savez(
+                     outpath,
+                     X=X_pad[:, :, start:start + cropsize],
+                     y=y_pad[:, :, start:start + cropsize]
+                 )
+             patch_list.append(outpath)
+
+     return patch_list
+
+
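+ # Oracle sampling (hard-example mining): draw n samples at random from the
+ # k highest-loss examples; k is inflated by the drop rate so the picks
+ # stay stochastic.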
+ def get_oracle_data(X, y, oracle_loss, oracle_rate, oracle_drop_rate):
+     k = int(len(X) * oracle_rate * (1 / (1 - oracle_drop_rate)))
+     n = int(len(X) * oracle_rate)
+     indices = np.argsort(oracle_loss)[::-1][:k]
+     indices = np.random.choice(indices, n, replace=False)
+     oracle_X = X[indices].copy()
+     oracle_y = y[indices].copy()
+
+     return oracle_X, oracle_y, indices
+
+
+ if __name__ == "__main__":
+     import sys
+     import utils
+
+     mix_dir = sys.argv[1]
+     inst_dir = sys.argv[2]
+     outdir = sys.argv[3]
+
+     os.makedirs(outdir, exist_ok=True)
+
+     filelist = make_pair(mix_dir, inst_dir)
+     for mix_path, inst_path in tqdm(filelist):
+         mix_basename = os.path.splitext(os.path.basename(mix_path))[0]
+
+         X_spec, y_spec, _, _ = spec_utils.cache_or_load(
+             mix_path, inst_path, 44100, 1024, 2048
+         )
+
+         X_mag = np.abs(X_spec)
+         y_mag = np.abs(y_spec)
+         v_mag = X_mag - y_mag
+         v_mag *= v_mag > y_mag
+
+         outpath = '{}/{}_Vocal.jpg'.format(outdir, mix_basename)
+         v_image = spec_utils.spectrogram_to_image(v_mag)
+         utils.imwrite(outpath, v_image)
lib/layers.py ADDED
@@ -0,0 +1,160 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from lib import spec_utils
+
+
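+ # Conv2d -> BatchNorm2d -> activation, the basic block used throughout.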
+ class Conv2DBNActiv(nn.Module):
+
+     def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
+         super(Conv2DBNActiv, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(
+                 nin, nout,
+                 kernel_size=ksize,
+                 stride=stride,
+                 padding=pad,
+                 dilation=dilation,
+                 bias=False
+             ),
+             nn.BatchNorm2d(nout),
+             activ()
+         )
+
+     def __call__(self, x):
+         return self.conv(x)
+
+
+ # class SeperableConv2DBNActiv(nn.Module):
+
+ #     def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
+ #         super(SeperableConv2DBNActiv, self).__init__()
+ #         self.conv = nn.Sequential(
+ #             nn.Conv2d(
+ #                 nin, nin,
+ #                 kernel_size=ksize,
+ #                 stride=stride,
+ #                 padding=pad,
+ #                 dilation=dilation,
+ #                 groups=nin,
+ #                 bias=False
+ #             ),
+ #             nn.Conv2d(
+ #                 nin, nout,
+ #                 kernel_size=1,
+ #                 bias=False
+ #             ),
+ #             nn.BatchNorm2d(nout),
+ #             activ()
+ #         )
+
+ #     def __call__(self, x):
+ #         return self.conv(x)
+
+
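+ # Two stacked conv blocks; the first may downsample via its stride.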
+ class Encoder(nn.Module):
+
+     def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
+         super(Encoder, self).__init__()
+         self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
+         self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
+
+     def __call__(self, x):
+         h = self.conv1(x)
+         h = self.conv2(h)
+
+         return h
+
+
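+ # Upsample 2x, optionally concatenate a center-cropped skip connection,
+ # then convolve (with optional dropout).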
+ class Decoder(nn.Module):
+
+     def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
+         super(Decoder, self).__init__()
+         self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
+         # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
+         self.dropout = nn.Dropout2d(0.1) if dropout else None
+
+     def __call__(self, x, skip=None):
+         x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
+
+         if skip is not None:
+             skip = spec_utils.crop_center(skip, x)
+             x = torch.cat([x, skip], dim=1)
+
+         h = self.conv1(x)
+         # h = self.conv2(h)
+
+         if self.dropout is not None:
+             h = self.dropout(h)
+
+         return h
+
+
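+ # Atrous Spatial Pyramid Pooling: a pooled global branch, a 1x1 conv, and
+ # three dilated 3x3 convs, concatenated and fused by a 1x1 bottleneck.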
+ class ASPPModule(nn.Module):
+
+     def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
+         super(ASPPModule, self).__init__()
+         self.conv1 = nn.Sequential(
+             nn.AdaptiveAvgPool2d((1, None)),
+             Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
+         )
+         self.conv2 = Conv2DBNActiv(
+             nin, nout, 1, 1, 0, activ=activ
+         )
+         self.conv3 = Conv2DBNActiv(
+             nin, nout, 3, 1, dilations[0], dilations[0], activ=activ
+         )
+         self.conv4 = Conv2DBNActiv(
+             nin, nout, 3, 1, dilations[1], dilations[1], activ=activ
+         )
+         self.conv5 = Conv2DBNActiv(
+             nin, nout, 3, 1, dilations[2], dilations[2], activ=activ
+         )
+         self.bottleneck = Conv2DBNActiv(
+             nout * 5, nout, 1, 1, 0, activ=activ
+         )
+         self.dropout = nn.Dropout2d(0.1) if dropout else None
+
+     def forward(self, x):
+         _, _, h, w = x.size()
+         feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
+         feat2 = self.conv2(x)
+         feat3 = self.conv3(x)
+         feat4 = self.conv4(x)
+         feat5 = self.conv5(x)
+         out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
+         out = self.bottleneck(out)
+
+         if self.dropout is not None:
+             out = self.dropout(out)
+
+         return out
+
+
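+ # Squeeze channels to 1 with a 1x1 conv, run a bidirectional LSTM across
+ # time frames, and project each frame back to nbins as a 1-channel map.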
+ class LSTMModule(nn.Module):
+
+     def __init__(self, nin_conv, nin_lstm, nout_lstm):
+         super(LSTMModule, self).__init__()
+         self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
+         self.lstm = nn.LSTM(
+             input_size=nin_lstm,
+             hidden_size=nout_lstm // 2,
+             bidirectional=True
+         )
+         self.dense = nn.Sequential(
+             nn.Linear(nout_lstm, nin_lstm),
+             nn.BatchNorm1d(nin_lstm),
+             nn.ReLU()
+         )
+
+     def forward(self, x):
+         N, _, nbins, nframes = x.size()
+         h = self.conv(x)[:, 0]  # N, nbins, nframes
+         h = h.permute(2, 0, 1)  # nframes, N, nbins
+         h, _ = self.lstm(h)
+         h = self.dense(h.reshape(-1, h.size()[-1]))  # nframes * N, nbins
+         h = h.reshape(nframes, N, 1, nbins)
+         h = h.permute(1, 2, 3, 0)
+
+         return h
lib/nets.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from lib import layers
+
+
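+ # U-Net-style band network: five encoder stages, an ASPP bottleneck, and a
+ # decoder whose second stage is augmented with an LSTM branch along time.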
+ class BaseNet(nn.Module):
+
+     def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))):
+         super(BaseNet, self).__init__()
+         self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1)
+         self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1)
+         self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1)
+         self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1)
+         self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1)
+
+         self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True)
+
+         self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1)
+         self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1)
+         self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1)
+         self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm)
+         self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)
+
+     def __call__(self, x):
+         e1 = self.enc1(x)
+         e2 = self.enc2(e1)
+         e3 = self.enc3(e2)
+         e4 = self.enc4(e3)
+         e5 = self.enc5(e4)
+
+         h = self.aspp(e5)
+
+         h = self.dec4(h, e4)
+         h = self.dec3(h, e3)
+         h = self.dec2(h, e2)
+         h = torch.cat([h, self.lstm_dec2(h)], dim=1)
+         h = self.dec1(h, e1)
+
+         return h
+
+
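+ # Three-stage cascade: stages 1 and 2 process the low and high frequency
+ # bands separately, stage 3 refines the full band. Output is a sigmoid
+ # mask padded to n_fft // 2 + 1 bins (plus an auxiliary mask in training).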
+ class CascadedNet(nn.Module):
+
+     def __init__(self, n_fft, nout=32, nout_lstm=128):
+         super(CascadedNet, self).__init__()
+         self.max_bin = n_fft // 2
+         self.output_bin = n_fft // 2 + 1
+         self.nin_lstm = self.max_bin // 2
+         self.offset = 64
+
+         self.stg1_low_band_net = nn.Sequential(
+             BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
+             layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
+         )
+         self.stg1_high_band_net = BaseNet(
+             2, nout // 4, self.nin_lstm // 2, nout_lstm // 2
+         )
+
+         self.stg2_low_band_net = nn.Sequential(
+             BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
+             layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
+         )
+         self.stg2_high_band_net = BaseNet(
+             nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
+         )
+
+         self.stg3_full_band_net = BaseNet(
+             3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm
+         )
+
+         self.out = nn.Conv2d(nout, 2, 1, bias=False)
+         self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
+
+     def forward(self, x):
+         x = x[:, :, :self.max_bin]
+
+         bandw = x.size()[2] // 2
+         l1_in = x[:, :, :bandw]
+         h1_in = x[:, :, bandw:]
+         l1 = self.stg1_low_band_net(l1_in)
+         h1 = self.stg1_high_band_net(h1_in)
+         aux1 = torch.cat([l1, h1], dim=2)
+
+         l2_in = torch.cat([l1_in, l1], dim=1)
+         h2_in = torch.cat([h1_in, h1], dim=1)
+         l2 = self.stg2_low_band_net(l2_in)
+         h2 = self.stg2_high_band_net(h2_in)
+         aux2 = torch.cat([l2, h2], dim=2)
+
+         f3_in = torch.cat([x, aux1, aux2], dim=1)
+         f3 = self.stg3_full_band_net(f3_in)
+
+         mask = torch.sigmoid(self.out(f3))
+         mask = F.pad(
+             input=mask,
+             pad=(0, 0, 0, self.output_bin - mask.size()[2]),
+             mode='replicate'
+         )
+
+         if self.training:
+             aux = torch.cat([aux1, aux2], dim=1)
+             aux = torch.sigmoid(self.aux_out(aux))
+             aux = F.pad(
+                 input=aux,
+                 pad=(0, 0, 0, self.output_bin - aux.size()[2]),
+                 mode='replicate'
+             )
+             return mask, aux
+         else:
+             return mask
+
+     def predict_mask(self, x):
+         mask = self.forward(x)
+
+         if self.offset > 0:
+             mask = mask[:, :, :, self.offset:-self.offset]
+             assert mask.size()[3] > 0
+
+         return mask
+
+     def predict(self, x):
+         mask = self.forward(x)
+         pred_mag = x * mask
+
+         if self.offset > 0:
+             pred_mag = pred_mag[:, :, :, self.offset:-self.offset]
+             assert pred_mag.size()[3] > 0
+
+         return pred_mag
lib/spec_utils.py ADDED
@@ -0,0 +1,227 @@
+ import os
+
+ import librosa
+ import numpy as np
+ import soundfile as sf
+
+
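+ # Center-crop h1 along the time axis to match h2's width (aligns skip
+ # connections in the decoder).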
+ def crop_center(h1, h2):
+     h1_shape = h1.size()
+     h2_shape = h2.size()
+
+     if h1_shape[3] == h2_shape[3]:
+         return h1
+     elif h1_shape[3] < h2_shape[3]:
+         raise ValueError('h1_shape[3] must be greater than h2_shape[3]')
+
+     # s_freq = (h2_shape[2] - h1_shape[2]) // 2
+     # e_freq = s_freq + h1_shape[2]
+     s_time = (h1_shape[3] - h2_shape[3]) // 2
+     e_time = s_time + h2_shape[3]
+     h1 = h1[:, :, :, s_time:e_time]
+
+     return h1
+
+
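+ # Per-channel STFT of a stereo waveform. The positional arguments here and
+ # in the librosa.load calls below follow the older (pre-0.10) librosa API.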
+ def wave_to_spectrogram(wave, hop_length, n_fft):
+     wave_left = np.asfortranarray(wave[0])
+     wave_right = np.asfortranarray(wave[1])
+
+     spec_left = librosa.stft(wave_left, n_fft, hop_length=hop_length)
+     spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length)
+     spec = np.asfortranarray([spec_left, spec_right])
+
+     return spec
+
+
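+ # Render a spectrogram as an 8-bit image (log power in magnitude mode);
+ # stereo input becomes 3 channels led by the per-pixel channel max.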
+ def spectrogram_to_image(spec, mode='magnitude'):
+     if mode == 'magnitude':
+         if np.iscomplexobj(spec):
+             y = np.abs(spec)
+         else:
+             y = spec
+         y = np.log10(y ** 2 + 1e-8)
+     elif mode == 'phase':
+         if np.iscomplexobj(spec):
+             y = np.angle(spec)
+         else:
+             y = spec
+
+     y -= y.min()
+     y *= 255 / y.max()
+     img = np.uint8(y)
+
+     if y.ndim == 3:
+         img = img.transpose(1, 2, 0)
+         img = np.concatenate([
+             np.max(img, axis=2, keepdims=True), img
+         ], axis=2)
+
+     return img
+
+
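+ # Estimate the vocal magnitude as mix minus instrumental (kept only where
+ # it dominates the instrumental) and subtract a weighted share of it from
+ # the instrumental target, reusing the instrumental phase.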
+ def aggressively_remove_vocal(X, y, weight):
+     X_mag = np.abs(X)
+     y_mag = np.abs(y)
+     # v_mag = np.abs(X_mag - y_mag)
+     v_mag = X_mag - y_mag
+     v_mag *= v_mag > y_mag
+
+     y_mag = np.clip(y_mag - v_mag * weight, 0, np.inf)
+
+     return y_mag * np.exp(1.j * np.angle(y))
+
+
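+ # Find time ranges where the whole mask stays above `thres` for more than
+ # `min_range` frames and blend the mask toward 1 there, with
+ # `fade_size`-frame ramps at the edges.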
+ def merge_artifacts(y_mask, thres=0.05, min_range=64, fade_size=32):
+     if min_range < fade_size * 2:
+         raise ValueError('min_range must be >= fade_size * 2')
+
+     idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
+     start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
+     end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
+     artifact_idx = np.where(end_idx - start_idx > min_range)[0]
+     weight = np.zeros_like(y_mask)
+     if len(artifact_idx) > 0:
+         start_idx = start_idx[artifact_idx]
+         end_idx = end_idx[artifact_idx]
+         old_e = None
+         for s, e in zip(start_idx, end_idx):
+             if old_e is not None and s - old_e < fade_size:
+                 s = old_e - fade_size * 2
+
+             if s != 0:
+                 weight[:, :, s:s + fade_size] = np.linspace(0, 1, fade_size)
+             else:
+                 s -= fade_size
+
+             if e != y_mask.shape[2]:
+                 weight[:, :, e - fade_size:e] = np.linspace(1, 0, fade_size)
+             else:
+                 e += fade_size
+
+             weight[:, :, s + fade_size:e - fade_size] = 1
+             old_e = e
+
+     v_mask = 1 - y_mask
+     y_mask += weight * v_mask
+
+     return y_mask
+
+
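+ # Trim leading/trailing silence, estimate the inter-take delay by
+ # cross-correlating the first four seconds, then crop both waveforms to a
+ # common length.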
+ def align_wave_head_and_tail(a, b, sr):
+     a, _ = librosa.effects.trim(a)
+     b, _ = librosa.effects.trim(b)
+
+     a_mono = a[:, :sr * 4].sum(axis=0)
+     b_mono = b[:, :sr * 4].sum(axis=0)
+
+     a_mono -= a_mono.mean()
+     b_mono -= b_mono.mean()
+
+     offset = len(a_mono) - 1
+     delay = np.argmax(np.correlate(a_mono, b_mono, 'full')) - offset
+
+     if delay > 0:
+         a = a[:, delay:]
+     else:
+         b = b[:, np.abs(delay):]
+
+     if a.shape[1] < b.shape[1]:
+         b = b[:, :a.shape[1]]
+     else:
+         a = a[:, :b.shape[1]]
+
+     return a, b
+
+
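+ # Load an aligned (mix, instrumental) pair as stereo spectrograms, caching
+ # the result as .npy files keyed by sr/hop_length/n_fft.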
+ def cache_or_load(mix_path, inst_path, sr, hop_length, n_fft):
+     mix_basename = os.path.splitext(os.path.basename(mix_path))[0]
+     inst_basename = os.path.splitext(os.path.basename(inst_path))[0]
+
+     cache_dir = 'sr{}_hl{}_nf{}'.format(sr, hop_length, n_fft)
+     mix_cache_dir = os.path.join(os.path.dirname(mix_path), cache_dir)
+     inst_cache_dir = os.path.join(os.path.dirname(inst_path), cache_dir)
+     os.makedirs(mix_cache_dir, exist_ok=True)
+     os.makedirs(inst_cache_dir, exist_ok=True)
+
+     mix_cache_path = os.path.join(mix_cache_dir, mix_basename + '.npy')
+     inst_cache_path = os.path.join(inst_cache_dir, inst_basename + '.npy')
+
+     if os.path.exists(mix_cache_path) and os.path.exists(inst_cache_path):
+         X = np.load(mix_cache_path)
+         y = np.load(inst_cache_path)
+     else:
+         X, _ = librosa.load(
+             mix_path, sr, False, dtype=np.float32, res_type='kaiser_fast')
+         y, _ = librosa.load(
+             inst_path, sr, False, dtype=np.float32, res_type='kaiser_fast')
+
+         X, y = align_wave_head_and_tail(X, y, sr)
+
+         X = wave_to_spectrogram(X, hop_length, n_fft)
+         y = wave_to_spectrogram(y, hop_length, n_fft)
+
+         np.save(mix_cache_path, X)
+         np.save(inst_cache_path, y)
+
+     return X, y, mix_cache_path, inst_cache_path
+
+
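+ # Per-channel inverse STFT back to a (stereo) waveform.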
+ def spectrogram_to_wave(spec, hop_length=1024):
+     if spec.ndim == 2:
+         wave = librosa.istft(spec, hop_length=hop_length)
+     elif spec.ndim == 3:
+         spec_left = np.asfortranarray(spec[0])
+         spec_right = np.asfortranarray(spec[1])
+
+         wave_left = librosa.istft(spec_left, hop_length=hop_length)
+         wave_right = librosa.istft(spec_right, hop_length=hop_length)
+         wave = np.asfortranarray([wave_left, wave_right])
+
+     return wave
+
+
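+ # Quick sanity check: estimate the vocal component of a (mix,
+ # instrumental) pair and dump spectrogram images and wav files.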
+ if __name__ == "__main__":
+     import cv2
+     import sys
+
+     bins = 2048 // 2 + 1
+     freq_to_bin = 2 * bins / 44100
+     unstable_bins = int(200 * freq_to_bin)
+     stable_bins = int(22050 * freq_to_bin)
+     reduction_weight = np.concatenate([
+         np.linspace(0, 1, unstable_bins, dtype=np.float32)[:, None],
+         np.linspace(1, 0, stable_bins - unstable_bins, dtype=np.float32)[:, None],
+         np.zeros((bins - stable_bins, 1))
+     ], axis=0) * 0.2
+
+     X, _ = librosa.load(
+         sys.argv[1], 44100, False, dtype=np.float32, res_type='kaiser_fast')
+     y, _ = librosa.load(
+         sys.argv[2], 44100, False, dtype=np.float32, res_type='kaiser_fast')
+
+     X, y = align_wave_head_and_tail(X, y, 44100)
+     X_spec = wave_to_spectrogram(X, 1024, 2048)
+     y_spec = wave_to_spectrogram(y, 1024, 2048)
+
+     X_mag = np.abs(X_spec)
+     y_mag = np.abs(y_spec)
+     # v_mag = np.abs(X_mag - y_mag)
+     v_mag = X_mag - y_mag
+     v_mag *= v_mag > y_mag
+
+     # y_mag = np.clip(y_mag - v_mag * reduction_weight, 0, np.inf)
+     y_spec = y_mag * np.exp(1j * np.angle(y_spec))
+     v_spec = v_mag * np.exp(1j * np.angle(X_spec))
+
+     X_image = spectrogram_to_image(X_mag)
+     y_image = spectrogram_to_image(y_mag)
+     v_image = spectrogram_to_image(v_mag)
+
+     cv2.imwrite('test_X.jpg', X_image)
+     cv2.imwrite('test_y.jpg', y_image)
+     cv2.imwrite('test_v.jpg', v_image)
+
+     sf.write('test_X.wav', spectrogram_to_wave(X_spec).T, 44100)
+     sf.write('test_y.wav', spectrogram_to_wave(y_spec).T, 44100)
+     sf.write('test_v.wav', spectrogram_to_wave(v_spec).T, 44100)
lib/utils.py ADDED
@@ -0,0 +1,30 @@
+ import os
+
+ import cv2
+ import numpy as np
+
+
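+ # Wrappers around cv2 that decode/encode in memory via np.fromfile/tofile,
+ # which keeps non-ASCII paths working (cv2.imread/imwrite can fail on them).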
+ def imread(filename, flags=cv2.IMREAD_COLOR, dtype=np.uint8):
+     try:
+         n = np.fromfile(filename, dtype)
+         img = cv2.imdecode(n, flags)
+         return img
+     except Exception as e:
+         print(e)
+         return None
+
+
+ def imwrite(filename, img, params=None):
+     try:
+         ext = os.path.splitext(filename)[1]
+         result, n = cv2.imencode(ext, img, params)
+
+         if result:
+             with open(filename, mode='w+b') as f:
+                 n.tofile(f)
+             return True
+         else:
+             return False
+     except Exception as e:
+         print(e)
+         return False