antonbol committed on
Commit
bc67f97
1 Parent(s): af564b2

added inference.py

Browse files
Files changed (1) hide show
  1. inference.py +180 -0
inference.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from lib import dataset
11
+ from lib import nets
12
+ from lib import spec_utils
13
+ from lib import utils
14
+
15
+
16
class Separator:
    """Apply a mask-predicting model to a spectrogram, patch by patch.

    The model is asked for a mask over the magnitude spectrogram.
    ``mask * X`` is returned as the first spectrogram and ``(1 - mask) * X``
    as the second.

    Args:
        model: Network exposing an ``offset`` attribute and a
            ``predict_mask(batch)`` method that returns a tensor.
        device: Torch device the input batches are moved to.
        batchsize: Number of patches per forward pass.
        cropsize: Width, along the last axis, of every patch fed to the model.
        postprocess: When true, the mask is passed through
            ``spec_utils.merge_artifacts`` before it is applied.
    """

    def __init__(self, model, device, batchsize, cropsize, postprocess=False):
        self.model = model
        self.offset = model.offset
        self.device = device
        self.batchsize = batchsize
        self.cropsize = cropsize
        self.postprocess = postprocess

    def _pad_and_normalize(self, X_mag, pad_l, pad_r):
        """Zero-pad the last axis of ``X_mag`` and scale its peak to 1.

        An all-zero (silent) input is returned unscaled: dividing by a zero
        peak would fill the array with NaN and poison the predicted mask.
        ``X_mag`` itself is not modified; ``np.pad`` returns a new array.
        """
        X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
        peak = X_mag_pad.max()
        if peak > 0:
            X_mag_pad /= peak

        return X_mag_pad

    def _separate(self, X_mag_pad, roi_size):
        """Predict a mask for an already padded magnitude spectrogram.

        Windows of ``cropsize`` frames are taken every ``roi_size`` frames
        along the last axis, run through the model in batches of
        ``batchsize``, and the per-window predictions are joined back
        together along the last axis.

        Returns:
            numpy array holding the concatenated predictions.
        """
        patches = (X_mag_pad.shape[2] - 2 * self.offset) // roi_size
        starts = [i * roi_size for i in range(patches)]

        self.model.eval()
        mask = []
        with torch.no_grad():
            # To reduce the overhead, dataloader is not used.
            for i in tqdm(range(0, patches, self.batchsize)):
                # Crops are stacked one batch at a time so the whole set of
                # overlapping windows never has to be held in memory at once.
                X_batch = np.stack([
                    X_mag_pad[:, :, start:start + self.cropsize]
                    for start in starts[i:i + self.batchsize]
                ])
                X_batch = torch.from_numpy(X_batch).to(self.device)

                pred = self.model.predict_mask(X_batch)
                pred = pred.detach().cpu().numpy()

                # Iterating `pred` yields one array per batch item; they are
                # laid end to end along the last axis.
                mask.append(np.concatenate(pred, axis=2))

        return np.concatenate(mask, axis=2)

    def _preprocess(self, X_spec):
        """Split a complex spectrogram into ``(magnitude, phase)`` arrays."""
        return np.abs(X_spec), np.angle(X_spec)

    def _postprocess(self, mask, X_mag, X_phase):
        """Apply ``mask`` and its complement and restore the original phase.

        Returns:
            ``(y_spec, v_spec)``: the complex spectrograms obtained from
            ``mask`` and from ``1 - mask`` respectively.
        """
        if self.postprocess:
            mask = spec_utils.merge_artifacts(mask)

        phase_factor = np.exp(1.j * X_phase)
        y_spec = mask * X_mag * phase_factor
        v_spec = (1 - mask) * X_mag * phase_factor

        return y_spec, v_spec

    def separate(self, X_spec):
        """Separate ``X_spec`` with a single pass over the spectrogram.

        Args:
            X_spec: Complex spectrogram, indexed as a 3-D array whose last
                axis is the frame axis.

        Returns:
            ``(y_spec, v_spec)`` as produced by ``_postprocess``.
        """
        X_mag, X_phase = self._preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset)
        X_mag_pad = self._pad_and_normalize(X_mag, pad_l, pad_r)

        # The padding makes the prediction longer than the input; trim it.
        mask = self._separate(X_mag_pad, roi_size)[:, :, :n_frame]

        return self._postprocess(mask, X_mag, X_phase)

    def separate_tta(self, X_spec):
        """Separate ``X_spec`` using two passes whose masks are averaged.

        The second pass pads both sides by an extra ``roi_size // 2`` frames
        so that the patch boundaries fall in different places; its mask is
        shifted back by the same amount before averaging.

        Args:
            X_spec: Complex spectrogram, indexed as a 3-D array whose last
                axis is the frame axis.

        Returns:
            ``(y_spec, v_spec)`` as produced by ``_postprocess``.
        """
        X_mag, X_phase = self._preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset)
        mask = self._separate(self._pad_and_normalize(X_mag, pad_l, pad_r), roi_size)

        shift = roi_size // 2
        X_mag_pad = self._pad_and_normalize(X_mag, pad_l + shift, pad_r + shift)
        mask_tta = self._separate(X_mag_pad, roi_size)
        # Drop the extra left padding so both masks are aligned again.
        mask_tta = mask_tta[:, :, shift:]

        mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5

        return self._postprocess(mask, X_mag, X_phase)
106
+
107
+
108
def main():
    """Command-line entry point: separate one audio file into two stems.

    Loads the checkpoint given by ``--pretrained_model`` into a
    ``nets.CascadedNet``, reads ``--input`` with librosa, runs a
    ``Separator`` over its spectrogram (optionally with ``--tta``), and
    writes ``<basename>_Instruments.wav`` and ``<basename>_Vocals.wav`` into
    ``--output_dir`` (the current directory when omitted). With
    ``--output_image`` a ``.jpg`` rendering of each spectrogram is written
    next to the audio.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
    p.add_argument('--input', '-i', required=True)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--n_fft', '-f', type=int, default=2048)
    p.add_argument('--hop_length', '-H', type=int, default=1024)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--cropsize', '-c', type=int, default=256)
    p.add_argument('--output_image', '-I', action='store_true')
    p.add_argument('--postprocess', '-p', action='store_true')
    p.add_argument('--tta', '-t', action='store_true')
    p.add_argument('--output_dir', '-o', type=str, default="")
    args = p.parse_args()

    print('loading model...', end=' ')
    device = torch.device('cpu')
    model = nets.CascadedNet(args.n_fft, 32, 128)
    # NOTE: torch.load unpickles the checkpoint, which can execute arbitrary
    # code; only point --pretrained_model at files from a trusted source.
    model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
    # The GPU is used only when one was requested AND CUDA is available;
    # otherwise the model stays on the CPU.
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
        model.to(device)
    print('done')

    print('loading wave source...', end=' ')
    # `sr` and `mono` are passed by keyword: librosa >= 0.10 rejects them as
    # positional arguments, while older releases accept both spellings.
    X, sr = librosa.load(
        args.input, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
    basename = os.path.splitext(os.path.basename(args.input))[0]
    print('done')

    if X.ndim == 1:
        # mono to stereo
        X = np.asarray([X, X])

    print('stft of wave source...', end=' ')
    X_spec = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
    print('done')

    sp = Separator(model, device, args.batchsize, args.cropsize, args.postprocess)

    if args.tta:
        y_spec, v_spec = sp.separate_tta(X_spec)
    else:
        y_spec, v_spec = sp.separate(X_spec)

    print('validating output directory...', end=' ')
    output_dir = args.output_dir
    if output_dir:
        # Only create a directory when one was actually requested; an empty
        # value means "write next to the current working directory".
        os.makedirs(output_dir, exist_ok=True)
    print('done')

    def output_path(suffix):
        # os.path.join copes with a missing or repeated trailing separator
        # and returns the bare file name when output_dir is empty.
        return os.path.join(output_dir, '{}{}'.format(basename, suffix))

    print('inverse stft of instruments...', end=' ')
    wave = spec_utils.spectrogram_to_wave(y_spec, hop_length=args.hop_length)
    print('done')
    sf.write(output_path('_Instruments.wav'), wave.T, sr)

    print('inverse stft of vocals...', end=' ')
    wave = spec_utils.spectrogram_to_wave(v_spec, hop_length=args.hop_length)
    print('done')
    sf.write(output_path('_Vocals.wav'), wave.T, sr)

    if args.output_image:
        image = spec_utils.spectrogram_to_image(y_spec)
        utils.imwrite(output_path('_Instruments.jpg'), image)

        image = spec_utils.spectrogram_to_image(v_spec)
        utils.imwrite(output_path('_Vocals.jpg'), image)
177
+
178
+
179
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()