victan commited on
Commit
cc83e80
1 Parent(s): 97681f9

Upload pseudo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pseudo.py +78 -0
pseudo.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import torch
8
+
9
+ from lib import dataset
10
+ from lib import nets
11
+ from lib import spec_utils
12
+
13
+ import inference
14
+
15
+
16
+ def main():
17
+ p = argparse.ArgumentParser()
18
+ p.add_argument('--gpu', '-g', type=int, default=-1)
19
+ p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
20
+ p.add_argument('--mixtures', '-m', required=True)
21
+ p.add_argument('--instruments', '-i', required=True)
22
+ p.add_argument('--sr', '-r', type=int, default=44100)
23
+ p.add_argument('--n_fft', '-f', type=int, default=2048)
24
+ p.add_argument('--hop_length', '-H', type=int, default=1024)
25
+ p.add_argument('--batchsize', '-B', type=int, default=4)
26
+ p.add_argument('--cropsize', '-c', type=int, default=256)
27
+ p.add_argument('--postprocess', '-p', action='store_true')
28
+ args = p.parse_args()
29
+
30
+ print('loading model...', end=' ')
31
+ device = torch.device('cpu')
32
+ model = nets.CascadedNet(args.n_fft, args.hop_length)
33
+ model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
34
+ if torch.cuda.is_available() and args.gpu >= 0:
35
+ device = torch.device('cuda:{}'.format(args.gpu))
36
+ model.to(device)
37
+ print('done')
38
+
39
+ filelist = dataset.make_pair(args.mixtures, args.instruments)
40
+ for mix_path, inst_path in filelist:
41
+ # if '_mixture' in mix_path and '_inst' in inst_path:
42
+ # continue
43
+ # else:
44
+ # pass
45
+
46
+ basename = os.path.splitext(os.path.basename(mix_path))[0]
47
+ print(basename)
48
+
49
+ print('loading wave source...', end=' ')
50
+ X, sr = librosa.load(
51
+ mix_path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
52
+ y, sr = librosa.load(
53
+ inst_path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
54
+ print('done')
55
+
56
+ if X.ndim == 1:
57
+ # mono to stereo
58
+ X = np.asarray([X, X])
59
+
60
+ print('stft of wave source...', end=' ')
61
+ X, y = spec_utils.align_wave_head_and_tail(X, y, sr)
62
+ X = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
63
+ y = spec_utils.wave_to_spectrogram(y, args.hop_length, args.n_fft)
64
+ print('done')
65
+
66
+ sp = inference.Separator(model, device, args.batchsize, args.cropsize, args.postprocess)
67
+ a_spec, _ = sp.separate_tta(X - y)
68
+
69
+ print('inverse stft of pseudo instruments...', end=' ')
70
+ pseudo_inst = y + a_spec
71
+ print('done')
72
+
73
+ sf.write('pseudo/{}_PseudoInstruments.wav'.format(basename), [0], sr)
74
+ np.save('pseudo/{}_PseudoInstruments.npy'.format(basename), pseudo_inst)
75
+
76
+
77
+ if __name__ == '__main__':
78
+ main()