victan commited on
Commit
97681f9
1 Parent(s): e102e80

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +172 -0
inference.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from lib import dataset
11
+ from lib import nets
12
+ from lib import spec_utils
13
+ from lib import utils
14
+
15
+
16
+ class Separator(object):
17
+
18
+ def __init__(self, model, device=None, batchsize=1, cropsize=256, postprocess=False):
19
+ self.model = model
20
+ self.offset = model.offset
21
+ self.device = device
22
+ self.batchsize = batchsize
23
+ self.cropsize = cropsize
24
+ self.postprocess = postprocess
25
+
26
+ def _postprocess(self, X_spec, mask):
27
+ if self.postprocess:
28
+ mask_mag = np.abs(mask)
29
+ mask_mag = spec_utils.merge_artifacts(mask_mag)
30
+ mask = mask_mag * np.exp(1.j * np.angle(mask))
31
+
32
+ y_spec = X_spec * mask
33
+ v_spec = X_spec - y_spec
34
+
35
+ return y_spec, v_spec
36
+
37
+ def _separate(self, X_spec_pad, roi_size):
38
+ X_dataset = []
39
+ patches = (X_spec_pad.shape[2] - 2 * self.offset) // roi_size
40
+ for i in range(patches):
41
+ start = i * roi_size
42
+ X_spec_crop = X_spec_pad[:, :, start:start + self.cropsize]
43
+ X_dataset.append(X_spec_crop)
44
+
45
+ X_dataset = np.asarray(X_dataset)
46
+
47
+ self.model.eval()
48
+ with torch.no_grad():
49
+ mask_list = []
50
+ # To reduce the overhead, dataloader is not used.
51
+ for i in tqdm(range(0, patches, self.batchsize)):
52
+ X_batch = X_dataset[i: i + self.batchsize]
53
+ X_batch = torch.from_numpy(X_batch).to(self.device)
54
+
55
+ mask = self.model.predict_mask(X_batch)
56
+
57
+ mask = mask.detach().cpu().numpy()
58
+ mask = np.concatenate(mask, axis=2)
59
+ mask_list.append(mask)
60
+
61
+ mask = np.concatenate(mask_list, axis=2)
62
+
63
+ return mask
64
+
65
+ def separate(self, X_spec):
66
+ n_frame = X_spec.shape[2]
67
+ pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset)
68
+ X_spec_pad = np.pad(X_spec, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
69
+ X_spec_pad /= np.abs(X_spec).max()
70
+
71
+ mask = self._separate(X_spec_pad, roi_size)
72
+ mask = mask[:, :, :n_frame]
73
+
74
+ y_spec, v_spec = self._postprocess(X_spec, mask)
75
+
76
+ return y_spec, v_spec
77
+
78
+ def separate_tta(self, X_spec):
79
+ n_frame = X_spec.shape[2]
80
+ pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset)
81
+ X_spec_pad = np.pad(X_spec, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
82
+ X_spec_pad /= X_spec_pad.max()
83
+
84
+ mask = self._separate(X_spec_pad, roi_size)
85
+
86
+ pad_l += roi_size // 2
87
+ pad_r += roi_size // 2
88
+ X_spec_pad = np.pad(X_spec, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
89
+ X_spec_pad /= X_spec_pad.max()
90
+
91
+ mask_tta = self._separate(X_spec_pad, roi_size)
92
+ mask_tta = mask_tta[:, :, roi_size // 2:]
93
+ mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
94
+
95
+ y_spec, v_spec = self._postprocess(X_spec, mask)
96
+
97
+ return y_spec, v_spec
98
+
99
+
100
+
101
+
102
+ def main(gpu=-1, pretrained_model='models/baseline.pth', input_file='', sr=44100, n_fft=2048,
103
+ hop_length=1024, batchsize=4, cropsize=256, output_image=False, tta=False, output_dir=""):
104
+
105
+ print('loading model...', end=' ')
106
+ device = torch.device('cpu')
107
+ if gpu >= 0:
108
+ if torch.cuda.is_available():
109
+ device = torch.device('cuda:{}'.format(gpu))
110
+ elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
111
+ device = torch.device('mps')
112
+ model = nets.CascadedNet(n_fft, hop_length, 32, 128, True)
113
+ model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
114
+ model.to(device)
115
+ print('done')
116
+
117
+ print('loading wave source...', end=' ')
118
+ print('loading wave source...', end=' ')
119
+ print("Chemin du fichier audio :", input_file) # Ajoutez cette ligne pour déboguer
120
+ X, sr = librosa.load(input_file, sr=sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
121
+ basename = os.path.splitext(os.path.basename(input_file))[0]
122
+ print('done')
123
+
124
+ if X.ndim == 1:
125
+ X = np.asarray([X, X])
126
+
127
+ print('stft of wave source...', end=' ')
128
+ X_spec = spec_utils.wave_to_spectrogram(X, hop_length, n_fft)
129
+ print('done')
130
+
131
+ sp = Separator(
132
+ model=model,
133
+ device=device,
134
+ batchsize=batchsize,
135
+ cropsize=cropsize,
136
+ )
137
+
138
+ if tta:
139
+ y_spec, v_spec = sp.separate_tta(X_spec)
140
+ else:
141
+ y_spec, v_spec = sp.separate(X_spec)
142
+
143
+ print('validating output directory...', end=' ')
144
+ if output_dir != "":
145
+ output_dir = output_dir.rstrip('/') + '/'
146
+ os.makedirs(output_dir, exist_ok=True)
147
+ print('done')
148
+
149
+ print('inverse stft of instruments...', end=' ')
150
+ wave = spec_utils.spectrogram_to_wave(y_spec, hop_length=hop_length)
151
+ print('done')
152
+ sf.write('{}{}_Instruments.wav'.format(output_dir, basename), wave.T, sr)
153
+
154
+ print('inverse stft of vocals...', end=' ')
155
+ wave = spec_utils.spectrogram_to_wave(v_spec, hop_length=hop_length)
156
+ print('done')
157
+ sf.write('{}{}_Vocals_finale.wav'.format(output_dir, basename), wave.T, sr)
158
+
159
+ if output_image:
160
+ image = spec_utils.spectrogram_to_image(y_spec)
161
+ utils.imwrite('{}{}_Instruments.jpg'.format(output_dir, basename), image)
162
+
163
+ image = spec_utils.spectrogram_to_image(v_spec)
164
+ utils.imwrite('{}{}_Vocals.jpg'.format(output_dir, basename), image)
165
+
166
+
167
+
168
+ import os
169
+ # Appel de la fonction avec des paramètres
170
+
171
+ main(input_file=os.getcwd()+'/audio_gnu.wav')
172
+