Hev832 commited on
Commit
4036374
1 Parent(s): f64e5e8

Create mdx.py

Browse files
Files changed (1) hide show
  1. mdx.py +289 -0
mdx.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import hashlib
3
+ import os
4
+ import queue
5
+ import threading
6
+ import warnings
7
+
8
+ import librosa
9
+ import numpy as np
10
+ import onnxruntime as ort
11
+ import soundfile as sf
12
+ import torch
13
+ from tqdm import tqdm
14
+
15
+ warnings.filterwarnings("ignore")
16
+ stem_naming = {'Vocals': 'Instrumental', 'Other': 'Instruments', 'Instrumental': 'Vocals', 'Drums': 'Drumless', 'Bass': 'Bassless'}
17
+
18
+
19
+ class MDXModel:
20
+ def __init__(self, device, dim_f, dim_t, n_fft, hop=1024, stem_name=None, compensation=1.000):
21
+ self.dim_f = dim_f
22
+ self.dim_t = dim_t
23
+ self.dim_c = 4
24
+ self.n_fft = n_fft
25
+ self.hop = hop
26
+ self.stem_name = stem_name
27
+ self.compensation = compensation
28
+
29
+ self.n_bins = self.n_fft // 2 + 1
30
+ self.chunk_size = hop * (self.dim_t - 1)
31
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
32
+
33
+ out_c = self.dim_c
34
+
35
+ self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)
36
+
37
+ def stft(self, x):
38
+ x = x.reshape([-1, self.chunk_size])
39
+ x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
40
+ x = torch.view_as_real(x)
41
+ x = x.permute([0, 3, 1, 2])
42
+ x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 4, self.n_bins, self.dim_t])
43
+ return x[:, :, :self.dim_f]
44
+
45
+ def istft(self, x, freq_pad=None):
46
+ freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
47
+ x = torch.cat([x, freq_pad], -2)
48
+ # c = 4*2 if self.target_name=='*' else 2
49
+ x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
50
+ x = x.permute([0, 2, 3, 1])
51
+ x = x.contiguous()
52
+ x = torch.view_as_complex(x)
53
+ x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
54
+ return x.reshape([-1, 2, self.chunk_size])
55
+
56
+
57
+ class MDX:
58
+ DEFAULT_SR = 44100
59
+ # Unit: seconds
60
+ DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
61
+ DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
62
+
63
+ DEFAULT_PROCESSOR = 0
64
+
65
+ def __init__(self, model_path: str, params: MDXModel, processor=DEFAULT_PROCESSOR):
66
+
67
+ # Set the device and the provider (CPU or CUDA)
68
+ self.device = torch.device(f'cuda:{processor}') if processor >= 0 else torch.device('cpu')
69
+ self.provider = ['CUDAExecutionProvider'] if processor >= 0 else ['CPUExecutionProvider']
70
+
71
+ self.model = params
72
+
73
+ # Load the ONNX model using ONNX Runtime
74
+ self.ort = ort.InferenceSession(model_path, providers=self.provider)
75
+ # Preload the model for faster performance
76
+ self.ort.run(None, {'input': torch.rand(1, 4, params.dim_f, params.dim_t).numpy()})
77
+ self.process = lambda spec: self.ort.run(None, {'input': spec.cpu().numpy()})[0]
78
+
79
+ self.prog = None
80
+
81
+ @staticmethod
82
+ def get_hash(model_path):
83
+ try:
84
+ with open(model_path, 'rb') as f:
85
+ f.seek(- 10000 * 1024, 2)
86
+ model_hash = hashlib.md5(f.read()).hexdigest()
87
+ except:
88
+ model_hash = hashlib.md5(open(model_path, 'rb').read()).hexdigest()
89
+
90
+ return model_hash
91
+
92
+ @staticmethod
93
+ def segment(wave, combine=True, chunk_size=DEFAULT_CHUNK_SIZE, margin_size=DEFAULT_MARGIN_SIZE):
94
+ """
95
+ Segment or join segmented wave array
96
+
97
+ Args:
98
+ wave: (np.array) Wave array to be segmented or joined
99
+ combine: (bool) If True, combines segmented wave array. If False, segments wave array.
100
+ chunk_size: (int) Size of each segment (in samples)
101
+ margin_size: (int) Size of margin between segments (in samples)
102
+
103
+ Returns:
104
+ numpy array: Segmented or joined wave array
105
+ """
106
+
107
+ if combine:
108
+ processed_wave = None # Initializing as None instead of [] for later numpy array concatenation
109
+ for segment_count, segment in enumerate(wave):
110
+ start = 0 if segment_count == 0 else margin_size
111
+ end = None if segment_count == len(wave) - 1 else -margin_size
112
+ if margin_size == 0:
113
+ end = None
114
+ if processed_wave is None: # Create array for first segment
115
+ processed_wave = segment[:, start:end]
116
+ else: # Concatenate to existing array for subsequent segments
117
+ processed_wave = np.concatenate((processed_wave, segment[:, start:end]), axis=-1)
118
+
119
+ else:
120
+ processed_wave = []
121
+ sample_count = wave.shape[-1]
122
+
123
+ if chunk_size <= 0 or chunk_size > sample_count:
124
+ chunk_size = sample_count
125
+
126
+ if margin_size > chunk_size:
127
+ margin_size = chunk_size
128
+
129
+ for segment_count, skip in enumerate(range(0, sample_count, chunk_size)):
130
+
131
+ margin = 0 if segment_count == 0 else margin_size
132
+ end = min(skip + chunk_size + margin_size, sample_count)
133
+ start = skip - margin
134
+
135
+ cut = wave[:, start:end].copy()
136
+ processed_wave.append(cut)
137
+
138
+ if end == sample_count:
139
+ break
140
+
141
+ return processed_wave
142
+
143
+ def pad_wave(self, wave):
144
+ """
145
+ Pad the wave array to match the required chunk size
146
+
147
+ Args:
148
+ wave: (np.array) Wave array to be padded
149
+
150
+ Returns:
151
+ tuple: (padded_wave, pad, trim)
152
+ - padded_wave: Padded wave array
153
+ - pad: Number of samples that were padded
154
+ - trim: Number of samples that were trimmed
155
+ """
156
+ n_sample = wave.shape[1]
157
+ trim = self.model.n_fft // 2
158
+ gen_size = self.model.chunk_size - 2 * trim
159
+ pad = gen_size - n_sample % gen_size
160
+
161
+ # Padded wave
162
+ wave_p = np.concatenate((np.zeros((2, trim)), wave, np.zeros((2, pad)), np.zeros((2, trim))), 1)
163
+
164
+ mix_waves = []
165
+ for i in range(0, n_sample + pad, gen_size):
166
+ waves = np.array(wave_p[:, i:i + self.model.chunk_size])
167
+ mix_waves.append(waves)
168
+
169
+ mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
170
+
171
+ return mix_waves, pad, trim
172
+
173
+ def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
174
+ """
175
+ Process each wave segment in a multi-threaded environment
176
+
177
+ Args:
178
+ mix_waves: (torch.Tensor) Wave segments to be processed
179
+ trim: (int) Number of samples trimmed during padding
180
+ pad: (int) Number of samples padded during padding
181
+ q: (queue.Queue) Queue to hold the processed wave segments
182
+ _id: (int) Identifier of the processed wave segment
183
+
184
+ Returns:
185
+ numpy array: Processed wave segment
186
+ """
187
+ mix_waves = mix_waves.split(1)
188
+ with torch.no_grad():
189
+ pw = []
190
+ for mix_wave in mix_waves:
191
+ self.prog.update()
192
+ spec = self.model.stft(mix_wave)
193
+ processed_spec = torch.tensor(self.process(spec))
194
+ processed_wav = self.model.istft(processed_spec.to(self.device))
195
+ processed_wav = processed_wav[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).cpu().numpy()
196
+ pw.append(processed_wav)
197
+ processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
198
+ q.put({_id: processed_signal})
199
+ return processed_signal
200
+
201
+ def process_wave(self, wave: np.array, mt_threads=1):
202
+ """
203
+ Process the wave array in a multi-threaded environment
204
+
205
+ Args:
206
+ wave: (np.array) Wave array to be processed
207
+ mt_threads: (int) Number of threads to be used for processing
208
+
209
+ Returns:
210
+ numpy array: Processed wave array
211
+ """
212
+ self.prog = tqdm(total=0)
213
+ chunk = wave.shape[-1] // mt_threads
214
+ waves = self.segment(wave, False, chunk)
215
+
216
+ # Create a queue to hold the processed wave segments
217
+ q = queue.Queue()
218
+ threads = []
219
+ for c, batch in enumerate(waves):
220
+ mix_waves, pad, trim = self.pad_wave(batch)
221
+ self.prog.total = len(mix_waves) * mt_threads
222
+ thread = threading.Thread(target=self._process_wave, args=(mix_waves, trim, pad, q, c))
223
+ thread.start()
224
+ threads.append(thread)
225
+ for thread in threads:
226
+ thread.join()
227
+ self.prog.close()
228
+
229
+ processed_batches = []
230
+ while not q.empty():
231
+ processed_batches.append(q.get())
232
+ processed_batches = [list(wave.values())[0] for wave in
233
+ sorted(processed_batches, key=lambda d: list(d.keys())[0])]
234
+ assert len(processed_batches) == len(waves), 'Incomplete processed batches, please reduce batch size!'
235
+ return self.segment(processed_batches, True, chunk)
236
+
237
+
238
+ def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False, suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2):
239
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
240
+
241
+ device_properties = torch.cuda.get_device_properties(device)
242
+ vram_gb = device_properties.total_memory / 1024**3
243
+ m_threads = 1 if vram_gb < 8 else 2
244
+
245
+ model_hash = MDX.get_hash(model_path)
246
+ mp = model_params.get(model_hash)
247
+ model = MDXModel(
248
+ device,
249
+ dim_f=mp["mdx_dim_f_set"],
250
+ dim_t=2 ** mp["mdx_dim_t_set"],
251
+ n_fft=mp["mdx_n_fft_scale_set"],
252
+ stem_name=mp["primary_stem"],
253
+ compensation=mp["compensate"]
254
+ )
255
+
256
+ mdx_sess = MDX(model_path, model)
257
+ wave, sr = librosa.load(filename, mono=False, sr=44100)
258
+ # normalizing input wave gives better output
259
+ peak = max(np.max(wave), abs(np.min(wave)))
260
+ wave /= peak
261
+ if denoise:
262
+ wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
263
+ wave_processed *= 0.5
264
+ else:
265
+ wave_processed = mdx_sess.process_wave(wave, m_threads)
266
+ # return to previous peak
267
+ wave_processed *= peak
268
+ stem_name = model.stem_name if suffix is None else suffix
269
+
270
+ main_filepath = None
271
+ if not exclude_main:
272
+ main_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
273
+ sf.write(main_filepath, wave_processed.T, sr)
274
+
275
+ invert_filepath = None
276
+ if not exclude_inversion:
277
+ diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
278
+ stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
279
+ invert_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
280
+ sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)
281
+
282
+ if not keep_orig:
283
+ os.remove(filename)
284
+
285
+ del mdx_sess, wave_processed, wave
286
+ if torch.cuda.is_available():
287
+ torch.cuda.empty_cache()
288
+ gc.collect()
289
+ return main_filepath, invert_filepath