Eddycrack864 committed on
Commit
ede719c
1 Parent(s): 240c2c0

Upload separate.py

Files changed (1)
  1. separate.py +355 -872
separate.py CHANGED
@@ -6,15 +6,13 @@ from demucs.model_v2 import auto_load_demucs_model_v2
 from demucs.pretrained import get_model as _gm
 from demucs.utils import apply_model_v1
 from demucs.utils import apply_model_v2
-from lib_v5.tfc_tdf_v3 import TFC_TDF_net, STFT
 from lib_v5 import spec_utils
 from lib_v5.vr_network import nets
 from lib_v5.vr_network import nets_new
-from lib_v5.vr_network.model_param_init import ModelParameters
+#from lib_v5.vr_network.model_param_init import ModelParameters
 from pathlib import Path
 from gui_data.constants import *
 from gui_data.error_handling import *
-from scipy import signal
 import audioread
 import gzip
 import librosa
@@ -26,85 +24,31 @@ import torch
 import warnings
 import pydub
 import soundfile as sf
+import traceback
 import lib_v5.mdxnet as MdxnetSet
-import math
-#import random
-from onnx import load
-from onnx2pytorch import ConvertModel
-import gc
-
+
 if TYPE_CHECKING:
     from UVR import ModelData

-# if not is_macos:
-#     import torch_directml
-
-mps_available = torch.backends.mps.is_available() if is_macos else False
-cuda_available = torch.cuda.is_available()
-
-# def get_gpu_info():
-#     directml_device, directml_available = DIRECTML_DEVICE, False
-
-#     if not is_macos:
-#         directml_available = torch_directml.is_available()
-
-#     if directml_available:
-#         directml_device = str(torch_directml.device()).partition(":")[0]
-
-#     return directml_device, directml_available
-
-# DIRECTML_DEVICE, directml_available = get_gpu_info()
-
-def clear_gpu_cache():
-    gc.collect()
-    if is_macos:
-        torch.mps.empty_cache()
-    else:
-        torch.cuda.empty_cache()
-
 warnings.filterwarnings("ignore")
 cpu = torch.device('cpu')

 class SeperateAttributes:
-    def __init__(self, model_data: ModelData,
-                 process_data: dict,
-                 main_model_primary_stem_4_stem=None,
-                 main_process_method=None,
-                 is_return_dual=True,
-                 main_model_primary=None,
-                 vocal_stem_path=None,
-                 master_inst_source=None,
-                 master_vocal_source=None):
+    def __init__(self, model_data: ModelData, process_data: dict, main_model_primary_stem_4_stem=None, main_process_method=None):

         self.list_all_models: list
         self.process_data = process_data
         self.progress_value = 0
         self.set_progress_bar = process_data['set_progress_bar']
         self.write_to_console = process_data['write_to_console']
-        if vocal_stem_path:
-            self.audio_file, self.audio_file_base = vocal_stem_path
-            self.audio_file_base_voc_split = lambda stem, split:os.path.join(self.export_path, f'{self.audio_file_base.replace("_(Vocals)", "")}_({stem}_{split}).wav')
-        else:
-            self.audio_file = process_data['audio_file']
-            self.audio_file_base = process_data['audio_file_base']
-            self.audio_file_base_voc_split = None
+        self.audio_file = process_data['audio_file']
+        self.audio_file_base = process_data['audio_file_base']
         self.export_path = process_data['export_path']
         self.cached_source_callback = process_data['cached_source_callback']
         self.cached_model_source_holder = process_data['cached_model_source_holder']
         self.is_4_stem_ensemble = process_data['is_4_stem_ensemble']
         self.list_all_models = process_data['list_all_models']
         self.process_iteration = process_data['process_iteration']
-        self.is_return_dual = is_return_dual
-        self.is_pitch_change = model_data.is_pitch_change
-        self.semitone_shift = model_data.semitone_shift
-        self.is_match_frequency_pitch = model_data.is_match_frequency_pitch
-        self.overlap = model_data.overlap
-        self.overlap_mdx = model_data.overlap_mdx
-        self.overlap_mdx23 = model_data.overlap_mdx23
-        self.is_mdx_combine_stems = model_data.is_mdx_combine_stems
-        self.is_mdx_c = model_data.is_mdx_c
-        self.mdx_c_configs = model_data.mdx_c_configs
-        self.mdxnet_stem_select = model_data.mdxnet_stem_select
         self.mixer_path = model_data.mixer_path
         self.model_samplerate = model_data.model_samplerate
         self.model_capacity = model_data.model_capacity
@@ -126,11 +70,9 @@ class SeperateAttributes:
         self.is_ensemble_mode = model_data.is_ensemble_mode
         self.secondary_model = model_data.secondary_model #
         self.primary_model_primary_stem = model_data.primary_model_primary_stem
-        self.primary_stem_native = model_data.primary_stem_native
         self.primary_stem = model_data.primary_stem #
         self.secondary_stem = model_data.secondary_stem #
         self.is_invert_spec = model_data.is_invert_spec #
-        self.is_deverb_vocals = model_data.is_deverb_vocals
         self.is_mixer_mode = model_data.is_mixer_mode #
         self.secondary_model_scale = model_data.secondary_model_scale #
         self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix #
@@ -140,87 +82,49 @@ class SeperateAttributes:
         self.secondary_source = None
         self.secondary_source_primary = None
         self.secondary_source_secondary = None
-        self.main_model_primary_stem_4_stem = main_model_primary_stem_4_stem
-        self.main_model_primary = main_model_primary
-        self.ensemble_primary_stem = model_data.ensemble_primary_stem
-        self.is_multi_stem_ensemble = model_data.is_multi_stem_ensemble
-        self.is_other_gpu = False
-        self.is_deverb = True
-        self.DENOISER_MODEL = model_data.DENOISER_MODEL
-        self.DEVERBER_MODEL = model_data.DEVERBER_MODEL
-        self.is_source_swap = False
-        self.vocal_split_model = model_data.vocal_split_model
-        self.is_vocal_split_model = model_data.is_vocal_split_model
-        self.master_vocal_path = None
-        self.set_master_inst_source = None
-        self.master_inst_source = master_inst_source
-        self.master_vocal_source = master_vocal_source
-        self.is_save_inst_vocal_splitter = isinstance(master_inst_source, np.ndarray) and model_data.is_save_inst_vocal_splitter
-        self.is_inst_only_voc_splitter = model_data.is_inst_only_voc_splitter
-        self.is_karaoke = model_data.is_karaoke
-        self.is_bv_model = model_data.is_bv_model
-        self.is_bv_model_rebalenced = model_data.bv_model_rebalance and self.is_vocal_split_model
-        self.is_sec_bv_rebalance = model_data.is_sec_bv_rebalance
-        self.stem_path_init = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
-        self.deverb_vocal_opt = model_data.deverb_vocal_opt
-        self.is_save_vocal_only = model_data.is_save_vocal_only
-        self.device = cpu
-        self.run_type = ['CPUExecutionProvider']
-        self.is_opencl = False
-        self.device_set = model_data.device_set
-        self.is_use_opencl = model_data.is_use_opencl
-
-        if self.is_inst_only_voc_splitter or self.is_sec_bv_rebalance:
-            self.is_primary_stem_only = False
-            self.is_secondary_stem_only = False
-
-        if main_model_primary and self.is_multi_stem_ensemble:
-            self.primary_stem, self.secondary_stem = main_model_primary, secondary_stem(main_model_primary)

-        if self.is_gpu_conversion >= 0:
-            if mps_available:
-                self.device, self.is_other_gpu = 'mps', True
-            else:
-                device_prefix = None
-                if self.device_set != DEFAULT:
-                    device_prefix = CUDA_DEVICE#DIRECTML_DEVICE if self.is_use_opencl and directml_available else CUDA_DEVICE

-                # if directml_available and self.is_use_opencl:
-                #     self.device = torch_directml.device() if not device_prefix else f'{device_prefix}:{self.device_set}'
-                #     self.is_other_gpu = True
-                if cuda_available:# and not self.is_use_opencl:
-                    self.device = CUDA_DEVICE if not device_prefix else f'{device_prefix}:{self.device_set}'
-                    self.run_type = ['CUDAExecutionProvider']

         if model_data.process_method == MDX_ARCH_TYPE:
             self.is_mdx_ckpt = model_data.is_mdx_ckpt
             self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename)
-            self.is_denoise = model_data.is_denoise#
-            self.is_denoise_model = model_data.is_denoise_model#
-            self.is_mdx_c_seg_def = model_data.is_mdx_c_seg_def#
             self.mdx_batch_size = model_data.mdx_batch_size
             self.compensate = model_data.compensate
-            self.mdx_segment_size = model_data.mdx_segment_size
-
-            if self.is_mdx_c:
-                if not self.is_4_stem_ensemble:
-                    self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
-                    self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
-            else:
-                self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
-
-            self.check_label_secondary_stem_runs()
             self.n_fft = model_data.mdx_n_fft_scale_set
             self.chunks = model_data.chunks
             self.margin = model_data.margin
             self.adjust = 1
             self.dim_c = 4
             self.hop = 1024

         if model_data.process_method == DEMUCS_ARCH_TYPE:
             self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None
             self.secondary_model_4_stem = model_data.secondary_model_4_stem
             self.secondary_model_4_stem_scale = model_data.secondary_model_4_stem_scale
             self.is_chunk_demucs = model_data.is_chunk_demucs
             self.segment = model_data.segment
             self.demucs_version = model_data.demucs_version
@@ -229,37 +133,28 @@ class SeperateAttributes:
             self.is_demucs_combine_stems = model_data.is_demucs_combine_stems
             self.demucs_stem_count = model_data.demucs_stem_count
             self.pre_proc_model = model_data.pre_proc_model
-            self.device = cpu if self.is_other_gpu and not self.demucs_version in [DEMUCS_V3, DEMUCS_V4] else self.device
-
-            self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
-            self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
-
-            if (self.is_multi_stem_ensemble or self.is_4_stem_ensemble) and not self.is_secondary_model:
-                self.is_return_dual = False

-            if self.is_multi_stem_ensemble and main_model_primary:
-                self.is_4_stem_ensemble = False
-                if main_model_primary in self.demucs_source_map.keys():
-                    self.primary_stem = main_model_primary
-                    self.secondary_stem = secondary_stem(main_model_primary)
-                elif secondary_stem(main_model_primary) in self.demucs_source_map.keys():
-                    self.primary_stem = secondary_stem(main_model_primary)
-                    self.secondary_stem = main_model_primary
-
             if self.is_secondary_model and not process_data['is_ensemble_master']:
                 if not self.demucs_stem_count == 2 and model_data.primary_model_primary_stem == INST_STEM:
                     self.primary_stem = VOCAL_STEM
                     self.secondary_stem = INST_STEM
                 else:
                     self.primary_stem = model_data.primary_model_primary_stem
-                    self.secondary_stem = secondary_stem(self.primary_stem)
-
             self.shifts = model_data.shifts
             self.is_split_mode = model_data.is_split_mode if not self.demucs_version == DEMUCS_V4 else True
             self.primary_model_name, self.primary_sources = self.cached_source_callback(DEMUCS_ARCH_TYPE, model_name=self.model_basename)

         if model_data.process_method == VR_ARCH_TYPE:
-            self.check_label_secondary_stem_runs()
             self.primary_model_name, self.primary_sources = self.cached_source_callback(VR_ARCH_TYPE, model_name=self.model_basename)
             self.mp = model_data.vr_model_param
             self.high_end_process = model_data.is_high_end_process
@@ -269,44 +164,28 @@ class SeperateAttributes:
             self.batch_size = model_data.batch_size
             self.window_size = model_data.window_size
             self.input_high_end_h = None
-            self.input_high_end = None
             self.post_process_threshold = model_data.post_process_threshold
             self.aggressiveness = {'value': model_data.aggression_setting,
                                    'split_bin': self.mp.param['band'][1]['crop_stop'],
                                    'aggr_correction': self.mp.param.get('aggr_correction')}

-    def check_label_secondary_stem_runs(self):
-
-        # For ensemble master that's not a 4-stem ensemble, and not mdx_c
-        if self.process_data['is_ensemble_master'] and not self.is_4_stem_ensemble and not self.is_mdx_c:
-            if self.ensemble_primary_stem != self.primary_stem:
-                self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only
-
-        # For secondary models
-        if self.is_pre_proc_model or self.is_secondary_model:
-            self.is_primary_stem_only = False
-            self.is_secondary_stem_only = False
-
     def start_inference_console_write(self):
-        if self.is_secondary_model and not self.is_pre_proc_model and not self.is_vocal_split_model:
             self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename))

         if self.is_pre_proc_model:
             self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename))
-
-        if self.is_vocal_split_model:
-            self.write_to_console(INFERENCE_STEP_2_VOC_S(self.process_method, self.model_basename))

     def running_inference_console_write(self, is_no_write=False):
         self.write_to_console(DONE, base_text='') if not is_no_write else None
         self.set_progress_bar(0.05) if not is_no_write else None

-        if self.is_secondary_model and not self.is_pre_proc_model and not self.is_vocal_split_model:
             self.write_to_console(INFERENCE_STEP_1_SEC)
         elif self.is_pre_proc_model:
             self.write_to_console(INFERENCE_STEP_1_PRE)
-        elif self.is_vocal_split_model:
-            self.write_to_console(INFERENCE_STEP_1_VOC_S)
         else:
             self.write_to_console(INFERENCE_STEP_1)

@@ -319,14 +198,19 @@ class SeperateAttributes:

         self.set_progress_bar(0.1, (0.8/length*self.progress_value))

-    def load_cached_sources(self):

        if self.is_secondary_model and not self.is_pre_proc_model:
            self.write_to_console(INFERENCE_STEP_2_SEC_CACHED_MODOEL(self.process_method, self.model_basename))
        elif self.is_pre_proc_model:
            self.write_to_console(INFERENCE_STEP_2_PRE_CACHED_MODOEL(self.process_method, self.model_basename))
        else:
-            self.write_to_console(INFERENCE_STEP_2_PRIMARY_CACHED, "")

    def cache_source(self, secondary_sources):

@@ -341,142 +225,49 @@ class SeperateAttributes:

         if self.process_method == DEMUCS_ARCH_TYPE:
             self.cached_model_source_holder(DEMUCS_ARCH_TYPE, secondary_sources, self.model_basename)
-
-    def process_vocal_split_chain(self, sources: dict):
-
-        def is_valid_vocal_split_condition(master_vocal_source):
-            """Checks if conditions for vocal split processing are met."""
-            conditions = [
-                isinstance(master_vocal_source, np.ndarray),
-                self.vocal_split_model,
-                not self.is_ensemble_mode,
-                not self.is_karaoke,
-                not self.is_bv_model
-            ]
-            return all(conditions)
-
-        # Retrieve sources from the dictionary with default fallbacks
-        master_inst_source = sources.get(INST_STEM, None)
-        master_vocal_source = sources.get(VOCAL_STEM, None)
-
-        # Process the vocal split chain if conditions are met
-        if is_valid_vocal_split_condition(master_vocal_source):
-            process_chain_model(
-                self.vocal_split_model,
-                self.process_data,
-                vocal_stem_path=self.master_vocal_path,
-                master_vocal_source=master_vocal_source,
-                master_inst_source=master_inst_source
-            )
-
-    def process_secondary_stem(self, stem_source, secondary_model_source=None, model_scale=None):
        if not self.is_secondary_model:
-            if self.is_secondary_model_activated and isinstance(secondary_model_source, np.ndarray):
-                secondary_model_scale = model_scale if model_scale else self.secondary_model_scale
-                stem_source = spec_utils.average_dual_sources(stem_source, secondary_model_source, secondary_model_scale)
-
-        return stem_source
-
-    def final_process(self, stem_path, source, secondary_source, stem_name, samplerate):
-        source = self.process_secondary_stem(source, secondary_source)
-        self.write_audio(stem_path, source, samplerate, stem_name=stem_name)
-
-        return {stem_name: source}
-
-    def write_audio(self, stem_path: str, stem_source, samplerate, stem_name=None):
-
-        def save_audio_file(path, source):
-            source = spec_utils.normalize(source, self.is_normalization)
-            sf.write(path, source, samplerate, subtype=self.wav_type_set)
-
-            if is_not_ensemble:
-                save_format(path, self.save_format, self.mp3_bit_set)
-
-        def save_voc_split_instrumental(stem_name, stem_source, is_inst_invert=False):
-            inst_stem_name = "Instrumental (With Lead Vocals)" if stem_name == LEAD_VOCAL_STEM else "Instrumental (With Backing Vocals)"
-            inst_stem_path_name = LEAD_VOCAL_STEM_I if stem_name == LEAD_VOCAL_STEM else BV_VOCAL_STEM_I
-            inst_stem_path = self.audio_file_base_voc_split(INST_STEM, inst_stem_path_name)
-            stem_source = -stem_source if is_inst_invert else stem_source
-            inst_stem_source = spec_utils.combine_arrarys([self.master_inst_source, stem_source], is_swap=True)
-            save_with_message(inst_stem_path, inst_stem_name, inst_stem_source)
-
-        def save_voc_split_vocal(stem_name, stem_source):
-            voc_split_stem_name = LEAD_VOCAL_STEM_LABEL if stem_name == LEAD_VOCAL_STEM else BV_VOCAL_STEM_LABEL
-            voc_split_stem_path = self.audio_file_base_voc_split(VOCAL_STEM, stem_name)
-            save_with_message(voc_split_stem_path, voc_split_stem_name, stem_source)
-
-        def save_with_message(stem_path, stem_name, stem_source):
-            is_deverb = self.is_deverb_vocals and (
-                self.deverb_vocal_opt == stem_name or
-                (self.deverb_vocal_opt == 'ALL' and
-                 (stem_name == VOCAL_STEM or stem_name == LEAD_VOCAL_STEM_LABEL or stem_name == BV_VOCAL_STEM_LABEL)))
-
-            self.write_to_console(f'{SAVING_STEM[0]}{stem_name}{SAVING_STEM[1]}')
-
-            if is_deverb and is_not_ensemble:
-                deverb_vocals(stem_path, stem_source)
-
-            save_audio_file(stem_path, stem_source)
            self.write_to_console(DONE, base_text='')
-
-        def deverb_vocals(stem_path:str, stem_source):
-            self.write_to_console(INFERENCE_STEP_DEVERBING, base_text='')
-            stem_source_deverbed, stem_source_2 = vr_denoiser(stem_source, self.device, is_deverber=True, model_path=self.DEVERBER_MODEL)
-            save_audio_file(stem_path.replace(".wav", "_deverbed.wav"), stem_source_deverbed)
-            save_audio_file(stem_path.replace(".wav", "_reverb_only.wav"), stem_source_2)
-
-        is_bv_model_lead = (self.is_bv_model_rebalenced and self.is_vocal_split_model and stem_name == LEAD_VOCAL_STEM)
-        is_bv_rebalance_lead = (self.is_bv_model_rebalenced and self.is_vocal_split_model and stem_name == BV_VOCAL_STEM)
-        is_no_vocal_save = self.is_inst_only_voc_splitter and (stem_name == VOCAL_STEM or stem_name == BV_VOCAL_STEM or stem_name == LEAD_VOCAL_STEM) or is_bv_model_lead
-        is_not_ensemble = (not self.is_ensemble_mode or self.is_vocal_split_model)
-        is_do_not_save_inst = (self.is_save_vocal_only and self.is_sec_bv_rebalance and stem_name == INST_STEM)
-
-        if is_bv_rebalance_lead:
-            master_voc_source = spec_utils.match_array_shapes(self.master_vocal_source, stem_source, is_swap=True)
-            bv_rebalance_lead_source = stem_source-master_voc_source
-
-        if not is_bv_model_lead and not is_do_not_save_inst:
-            if self.is_vocal_split_model or not self.is_secondary_model:
-                if self.is_vocal_split_model and not self.is_inst_only_voc_splitter:
-                    save_voc_split_vocal(stem_name, stem_source)
-                    if is_bv_rebalance_lead:
-                        save_voc_split_vocal(LEAD_VOCAL_STEM, bv_rebalance_lead_source)
-                else:
-                    if not is_no_vocal_save:
-                        save_with_message(stem_path, stem_name, stem_source)
-
-                if self.is_save_inst_vocal_splitter and not self.is_save_vocal_only:
-                    save_voc_split_instrumental(stem_name, stem_source)
-                    if is_bv_rebalance_lead:
-                        save_voc_split_instrumental(LEAD_VOCAL_STEM, bv_rebalance_lead_source, is_inst_invert=True)
-
-            self.set_progress_bar(0.95)

-        if stem_name == VOCAL_STEM:
-            self.master_vocal_path = stem_path
-
-    def pitch_fix(self, source, sr_pitched, org_mix):
-        semitone_shift = self.semitone_shift
-        source = spec_utils.change_pitch_semitones(source, sr_pitched, semitone_shift=semitone_shift)[0]
-        source = spec_utils.match_array_shapes(source, org_mix)
-        return source
-
-    def match_frequency_pitch(self, mix):
-        source = mix
-        if self.is_match_frequency_pitch and self.is_pitch_change:
-            source, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
-            source = self.pitch_fix(source, sr_pitched, mix)
-
-        return source

 class SeperateMDX(SeperateAttributes):

     def seperate(self):
         samplerate = 44100
-
-        if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
-            mix, source = self.primary_sources
-            self.load_cached_sources()
         else:
             self.start_inference_console_write()

@@ -486,145 +277,105 @@ class SeperateMDX(SeperateAttributes):
                 separator = MdxnetSet.ConvTDFNet(**model_params)
                 self.model_run = separator.load_from_checkpoint(self.model_path).to(self.device).eval()
             else:
-                if self.mdx_segment_size == self.dim_t and not self.is_other_gpu:
-                    ort_ = ort.InferenceSession(self.model_path, providers=self.run_type)
-                    self.model_run = lambda spek:ort_.run(None, {'input': spek.cpu().numpy()})[0]
-                else:
-                    self.model_run = ConvertModel(load(self.model_path))
-                    self.model_run.to(self.device).eval()

             self.running_inference_console_write()
-            mix = prepare_mix(self.audio_file)
-
-            source = self.demix(mix)
-
-            if not self.is_vocal_split_model:
-                self.cache_source((mix, source))
             self.write_to_console(DONE, base_text='')

-        mdx_net_cut = True if self.primary_stem in MDX_NET_FREQ_CUT and self.is_match_frequency_pitch else False
-
-        if self.is_secondary_model_activated and self.secondary_model:
-            self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem)

         if not self.is_primary_stem_only:
             secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
             if not isinstance(self.secondary_source, np.ndarray):
-                raw_mix = self.demix(self.match_frequency_pitch(mix), is_match_mix=True) if mdx_net_cut else self.match_frequency_pitch(mix)
-                self.secondary_source = spec_utils.invert_stem(raw_mix, source) if self.is_invert_spec else mix.T-source.T

-            self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, samplerate)
-
-        if not self.is_secondary_stem_only:
-            primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')

-            if not isinstance(self.primary_source, np.ndarray):
-                self.primary_source = source.T
-
-            self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, samplerate)
-
-        clear_gpu_cache()

         secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
-
-        self.process_vocal_split_chain(secondary_sources)

-        if self.is_secondary_model or self.is_pre_proc_model:
            return secondary_sources

    def initialize_model_settings(self):
        self.n_bins = self.n_fft//2+1
        self.trim = self.n_fft//2
-        self.chunk_size = self.hop * (self.mdx_segment_size-1)
        self.gen_size = self.chunk_size-2*self.trim
-        self.stft = STFT(self.n_fft, self.hop, self.dim_f, self.device)
-
-    def demix(self, mix, is_match_mix=False):
-        self.initialize_model_settings()
-
-        org_mix = mix
-        tar_waves_ = []

-        if is_match_mix:
-            chunk_size = self.hop * (256-1)
-            overlap = 0.02
        else:
-            chunk_size = self.chunk_size
-            overlap = self.overlap_mdx
-
-        if self.is_pitch_change:
-            mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
-
-        gen_size = chunk_size-2*self.trim
-
-        pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)
-        mixture = np.concatenate((np.zeros((2, self.trim), dtype='float32'), mix, np.zeros((2, pad), dtype='float32')), 1)
-
-        step = self.chunk_size - self.n_fft if overlap == DEFAULT else int((1 - overlap) * chunk_size)
-        result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
-        divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
-        total = 0
-        total_chunks = (mixture.shape[-1] + step - 1) // step
-
-        for i in range(0, mixture.shape[-1], step):
-            total += 1
-            start = i
-            end = min(i + chunk_size, mixture.shape[-1])
-
-            chunk_size_actual = end - start
-
-            if overlap == 0:
-                window = None
-            else:
-                window = np.hanning(chunk_size_actual)
-                window = np.tile(window[None, None, :], (1, 2, 1))
-
-            mix_part_ = mixture[:, start:end]
-            if end != i + chunk_size:
-                pad_size = (i + chunk_size) - end
-                mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype='float32')), axis=-1)

-            mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.device)
-            mix_waves = mix_part.split(self.mdx_batch_size)
-
            with torch.no_grad():
                for mix_wave in mix_waves:
-                    self.running_inference_progress_bar(total_chunks, is_match_mix=is_match_mix)
-
-                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
-
-                    if window is not None:
-                        tar_waves[..., :chunk_size_actual] *= window
-                        divider[..., start:end] += window
-                    else:
-                        divider[..., start:end] += 1
-
-                    result[..., start:end] += tar_waves[..., :end-start]
-
-        tar_waves = result / divider
-        tar_waves_.append(tar_waves)
-
-        tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim:-self.trim]
-        tar_waves = np.concatenate(tar_waves_, axis=-1)[:, :mix.shape[-1]]

-        source = tar_waves[:,0:None]
-
-        if self.is_pitch_change and not is_match_mix:
-            source = self.pitch_fix(source, sr_pitched, org_mix)
-
-        source = source if is_match_mix else source*self.compensate
-
-        if self.is_denoise_model and not is_match_mix:
-            if NO_STEM in self.primary_stem_native or self.primary_stem_native == INST_STEM:
-                if org_mix.shape[1] != source.shape[1]:
-                    source = spec_utils.match_array_shapes(source, org_mix)
-                source = org_mix - vr_denoiser(org_mix-source, self.device, model_path=self.DENOISER_MODEL)
-            else:
-                source = vr_denoiser(source, self.device, model_path=self.DENOISER_MODEL)
-
-        return source

-    def run_model(self, mix, is_match_mix=False):

        spek = self.stft(mix.to(self.device))*self.adjust
        spek[:, :, :3, :] *= 0
@@ -634,189 +385,58 @@ class SeperateMDX(SeperateAttributes):
         else:
             spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5 if self.is_denoise else self.model_run(spek)

-        return self.stft.inverse(torch.tensor(spec_pred).to(self.device)).cpu().detach().numpy()
-
-class SeperateMDXC(SeperateAttributes):

     def seperate(self):
-        samplerate = 44100
-        sources = None
-
-        if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
-            mix, sources = self.primary_sources
-            self.load_cached_sources()
-        else:
-            self.start_inference_console_write()
-            self.running_inference_console_write()
-            mix = prepare_mix(self.audio_file)
-            sources = self.demix(mix)
-            if not self.is_vocal_split_model:
-                self.cache_source((mix, sources))
-            self.write_to_console(DONE, base_text='')

-        stem_list = [self.mdx_c_configs.training.target_instrument] if self.mdx_c_configs.training.target_instrument else [i for i in self.mdx_c_configs.training.instruments]
-
-        if self.is_secondary_model:
-            if self.is_pre_proc_model:
-                self.mdxnet_stem_select = stem_list[0]
-            else:
-                self.mdxnet_stem_select = self.main_model_primary_stem_4_stem if self.main_model_primary_stem_4_stem else self.primary_model_primary_stem
-            self.primary_stem = self.mdxnet_stem_select
-            self.secondary_stem = secondary_stem(self.mdxnet_stem_select)
-            self.is_primary_stem_only, self.is_secondary_stem_only = False, False
-
-        is_all_stems = self.mdxnet_stem_select == ALL_STEMS
-        is_not_ensemble_master = not self.process_data['is_ensemble_master']
-        is_not_single_stem = not len(stem_list) <= 2
-        is_not_secondary_model = not self.is_secondary_model
-        is_ensemble_4_stem = self.is_4_stem_ensemble and is_not_single_stem
-
-        if (is_all_stems and is_not_ensemble_master and is_not_single_stem and is_not_secondary_model) or is_ensemble_4_stem and not self.is_pre_proc_model:
-            for stem in stem_list:
-                primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem}).wav')
-                self.primary_source = sources[stem].T
-                self.write_audio(primary_stem_path, self.primary_source, samplerate, stem_name=stem)
-
-                if stem == VOCAL_STEM and not self.is_sec_bv_rebalance:
-                    self.process_vocal_split_chain({VOCAL_STEM:stem})
-        else:
-            if len(stem_list) == 1:
-                source_primary = sources
-            else:
-                source_primary = sources[stem_list[0]] if self.is_multi_stem_ensemble and len(stem_list) == 2 else sources[self.mdxnet_stem_select]
-            if self.is_secondary_model_activated and self.secondary_model:
-                self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model,
-                                                                                                         self.process_data,
-                                                                                                         main_process_method=self.process_method,
-                                                                                                         main_model_primary=self.primary_stem)
-
-            if not self.is_primary_stem_only:
-                secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
-                if not isinstance(self.secondary_source, np.ndarray):
-
-                    if self.is_mdx_combine_stems and len(stem_list) >= 2:
-                        if len(stem_list) == 2:
-                            secondary_source = sources[self.secondary_stem]
-                        else:
-                            sources.pop(self.primary_stem)
-                            next_stem = next(iter(sources))
-                            secondary_source = np.zeros_like(sources[next_stem])
-                            for v in sources.values():
-                                secondary_source += v
-
-                        self.secondary_source = secondary_source.T
-                    else:
-                        self.secondary_source, raw_mix = source_primary, self.match_frequency_pitch(mix)
-                        self.secondary_source = spec_utils.to_shape(self.secondary_source, raw_mix.shape)
-
-                        if self.is_invert_spec:
-                            self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source)
-                        else:
-                            self.secondary_source = (-self.secondary_source.T+raw_mix.T)
-
-                self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, samplerate)
-
-            if not self.is_secondary_stem_only:
-                primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
-                if not isinstance(self.primary_source, np.ndarray):
-                    self.primary_source = source_primary.T
-
-                self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, samplerate)
-
-        clear_gpu_cache()
-
-        secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
-        self.process_vocal_split_chain(secondary_sources)
-
-        if self.is_secondary_model or self.is_pre_proc_model:
-            return secondary_sources
-
-    def demix(self, mix):
-        sr_pitched = 441000
-        org_mix = mix
-        if self.is_pitch_change:
-            mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
-
-        model = TFC_TDF_net(self.mdx_c_configs, device=self.device)
-        model.load_state_dict(torch.load(self.model_path, map_location=cpu))
-        model.to(self.device).eval()
-        mix = torch.tensor(mix, dtype=torch.float32)
-
-        try:
-            S = model.num_target_instruments
-        except Exception as e:
-            S = model.module.num_target_instruments
-
-        mdx_segment_size = self.mdx_c_configs.inference.dim_t if self.is_mdx_c_seg_def else self.mdx_segment_size
-
-        batch_size = self.mdx_batch_size
-        chunk_size = self.mdx_c_configs.audio.hop_length * (mdx_segment_size - 1)
-        overlap = self.overlap_mdx23
-
-        hop_size = chunk_size // overlap
-        mix_shape = mix.shape[1]
-        pad_size = hop_size - (mix_shape - chunk_size) % hop_size
-        mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
-
-        chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
-        batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
-
-        X = torch.zeros(S, *mix.shape) if S > 1 else torch.zeros_like(mix)
-        X = X.to(self.device)
-
-        with torch.no_grad():
-            cnt = 0
-            for batch in batches:
-                self.running_inference_progress_bar(len(batches))
-                x = model(batch.to(self.device))
-
-                for w in x:
-                    X[..., cnt * hop_size : cnt * hop_size + chunk_size] += w
-                    cnt += 1
-
-        estimated_sources = X[..., chunk_size - hop_size:-(pad_size + chunk_size - hop_size)] / overlap
-        del X
-        pitch_fix = lambda s:self.pitch_fix(s, sr_pitched, org_mix)
-
-        if S > 1:
-            sources = {k: pitch_fix(v) if self.is_pitch_change else v for k, v in zip(self.mdx_c_configs.training.instruments, estimated_sources.cpu().detach().numpy())}
-            del estimated_sources
-            if self.is_denoise_model:
-                if VOCAL_STEM in sources.keys() and INST_STEM in sources.keys():
-                    sources[VOCAL_STEM] = vr_denoiser(sources[VOCAL_STEM], self.device, model_path=self.DENOISER_MODEL)
-                    if sources[VOCAL_STEM].shape[1] != org_mix.shape[1]:
-                        sources[VOCAL_STEM] = spec_utils.match_array_shapes(sources[VOCAL_STEM], org_mix)
-                    sources[INST_STEM] = org_mix - sources[VOCAL_STEM]
-
-            return sources
-        else:
-            est_s = estimated_sources.cpu().detach().numpy()
-            del estimated_sources
-            return pitch_fix(est_s) if self.is_pitch_change else est_s
-
-class SeperateDemucs(SeperateAttributes):
-    def seperate(self):
         samplerate = 44100
         source = None
         model_scale = None
         stem_source = None
         stem_source_secondary = None
         inst_mix = None
         inst_source = None
         is_no_write = False
         is_no_piano_guitar = False
-        is_no_cache = False
-
-        if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and not self.pre_proc_model:
             source = self.primary_sources
-            self.load_cached_sources()
         else:
             self.start_inference_console_write()
-            is_no_cache = True

-        mix = prepare_mix(self.audio_file)
-
-        if is_no_cache:
             if self.demucs_version == DEMUCS_V1:
                 if str(self.model_path).endswith(".gz"):
                     self.model_path = gzip.open(self.model_path, "rb")
@@ -842,23 +462,26 @@ class SeperateDemucs(SeperateAttributes):
                 is_no_write = True
                 self.write_to_console(DONE, base_text='')
                 mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True)
-                inst_mix = prepare_mix(mix_no_voc[INST_STEM])
                 self.process_iteration()
                 self.running_inference_console_write(is_no_write=is_no_write)
                 inst_source = self.demix_demucs(inst_mix)
                 self.process_iteration()

             self.running_inference_console_write(is_no_write=is_no_write) if not self.pre_proc_model else None

             if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model:
                 source = self.primary_sources
             else:
                 source = self.demix_demucs(mix)

             self.write_to_console(DONE, base_text='')

             del self.demucs
-            clear_gpu_cache()

         if isinstance(inst_source, np.ndarray):
             source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]])
@@ -866,7 +489,6 @@ class SeperateDemucs(SeperateAttributes):
             source = inst_source

         if isinstance(source, np.ndarray):
-
             if len(source) == 2:
                 self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
             else:
@@ -881,40 +503,46 @@ class SeperateDemucs(SeperateAttributes):
                     other_source += i
                 source_reshape = spec_utils.reshape_sources(source[self.demucs_source_map[OTHER_STEM]], other_source)
                 source[self.demucs_source_map[OTHER_STEM]] = source_reshape
-
-        if not self.is_vocal_split_model:
             self.cache_source(source)
-
-        if (self.demucs_stems == ALL_STEMS and not self.process_data['is_ensemble_master']) or self.is_4_stem_ensemble and not self.is_return_dual:
             for stem_name, stem_value in self.demucs_source_map.items():
                 if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4:
                     if self.secondary_model_4_stem[stem_value]:
                         model_scale = self.secondary_model_4_stem_scale[stem_value]
-                        stem_source_secondary = process_secondary_model(self.secondary_model_4_stem[stem_value], self.process_data, main_model_primary_stem_4_stem=stem_name, is_source_load=True, is_return_dual=False)
                         if isinstance(stem_source_secondary, np.ndarray):
-                            stem_source_secondary = stem_source_secondary[1 if self.secondary_model_4_stem[stem_value].demucs_stem_count == 2 else stem_value].T
                         elif type(stem_source_secondary) is dict:
                             stem_source_secondary = stem_source_secondary[stem_name]

                 stem_source_secondary = None if stem_value >= 4 else stem_source_secondary
                 stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem_name}).wav')
-                stem_source = source[stem_value].T
-
-                stem_source = self.process_secondary_stem(stem_source, secondary_model_source=stem_source_secondary, model_scale=model_scale)
-                self.write_audio(stem_path, stem_source, samplerate, stem_name=stem_name)
-
-                if stem_name == VOCAL_STEM and not self.is_sec_bv_rebalance:
-                    self.process_vocal_split_chain({VOCAL_STEM:stem_source})
-
             if self.is_secondary_model:
                 return source
         else:
-            if self.is_secondary_model_activated and self.secondary_model:
                 self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)

             if not self.is_primary_stem_only:
                 def secondary_save(sec_stem_name, source, raw_mixture=None, is_inst_mixture=False):
                     secondary_source = self.secondary_source if not is_inst_mixture else None
                     secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({sec_stem_name}).wav')
                     secondary_source_secondary = None

@@ -930,12 +558,12 @@ class SeperateDemucs(SeperateAttributes):
                         secondary_source = np.zeros_like(source[0])
                         for i in source:
                             secondary_source += i
-                        secondary_source = secondary_source.T
                     else:
                         if not isinstance(raw_mixture, np.ndarray):
-                            raw_mixture = prepare_mix(self.audio_file)

-                        secondary_source = source[self.demucs_source_map[self.primary_stem]]

                     if self.is_invert_spec:
                         secondary_source = spec_utils.invert_stem(raw_mixture, secondary_source)
@@ -946,90 +574,86 @@ class SeperateDemucs(SeperateAttributes):
                     if not is_inst_mixture:
                         self.secondary_source = secondary_source
                         secondary_source_secondary = self.secondary_source_secondary
-                        self.secondary_source = self.process_secondary_stem(secondary_source, secondary_source_secondary)
                         self.secondary_source_map = {self.secondary_stem: self.secondary_source}

-                    self.write_audio(secondary_stem_path, secondary_source, samplerate, stem_name=sec_stem_name)

-                secondary_save(self.secondary_stem, source, raw_mixture=mix)

                 if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble:
-                    secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture=inst_mix, is_inst_mixture=True)
-
-            if not self.is_secondary_stem_only:
-                primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
-                if not isinstance(self.primary_source, np.ndarray):
-                    self.primary_source = source[self.demucs_source_map[self.primary_stem]].T
-
-                self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, samplerate)

             secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
-
-            self.process_vocal_split_chain(secondary_sources)

             if self.is_secondary_model:
                 return secondary_sources

     def demix_demucs(self, mix):
-
-        org_mix = mix
-
-        if self.is_pitch_change:
-            mix, sr_pitched = spec_utils.change_pitch_semitones(mix, 44100, semitone_shift=-self.semitone_shift)
-
         processed = {}
-        mix = torch.tensor(mix, dtype=torch.float32)
-        ref = mix.mean(0)
-        mix = (mix - ref.mean()) / ref.std()
-        mix_infer = mix
-
-        with torch.no_grad():
-            if self.demucs_version == DEMUCS_V1:
-                sources = apply_model_v1(self.demucs,
-                                         mix_infer.to(self.device),
-                                         self.shifts,
-                                         self.is_split_mode,
-                                         set_progress_bar=self.set_progress_bar)
-            elif self.demucs_version == DEMUCS_V2:
-                sources = apply_model_v2(self.demucs,
-                                         mix_infer.to(self.device),
                                          self.shifts,
                                          self.is_split_mode,
                                          self.overlap,
-                                         set_progress_bar=self.set_progress_bar)
-            else:
-                sources = apply_model(self.demucs,
-                                      mix_infer[None],
-                                      self.shifts,
-                                      self.is_split_mode,
-                                      self.overlap,
-                                      static_shifts=1 if self.shifts == 0 else self.shifts,
-                                      set_progress_bar=self.set_progress_bar,
-                                      device=self.device)[0]
-
-        sources = (sources * ref.std() + ref.mean()).cpu().numpy()
-        sources[[0,1]] = sources[[1,0]]
-        processed[mix] = sources[:,:,0:None].copy()
-        sources = list(processed.values())
-        sources = [s[:,:,0:None] for s in sources]
-        #sources = [self.pitch_fix(s[:,:,0:None], sr_pitched, org_mix) if self.is_pitch_change else s[:,:,0:None] for s in sources]
         sources = np.concatenate(sources, axis=-1)
-
-        if self.is_pitch_change:
-            sources = np.stack([self.pitch_fix(stem, sr_pitched, org_mix) for stem in sources])

         return sources

 class SeperateVR(SeperateAttributes):

     def seperate(self):
-        if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, tuple):
-            y_spec, v_spec = self.primary_sources
-            self.load_cached_sources()
         else:
             self.start_inference_console_write()
-
-            device = self.device

             nn_arch_sizes = [
                 31191, # default
@@ -1039,11 +663,7 @@ class SeperateVR(SeperateAttributes):
             nn_arch_size = min(nn_arch_sizes, key=lambda x:abs(x-model_size))

             if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
-                self.model_run = nets_new.CascadedNet(self.mp.param['bins'] * 2,
-                                                      nn_arch_size,
-                                                      nout=self.model_capacity[0],
-                                                      nout_lstm=self.model_capacity[1])
-                self.is_vr_51_model = True
             else:
                 self.model_run = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_arch_size)

@@ -1053,36 +673,41 @@ class SeperateVR(SeperateAttributes):
             self.running_inference_console_write()

             y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
-            if not self.is_vocal_split_model:
-                self.cache_source((y_spec, v_spec))
             self.write_to_console(DONE, base_text='')

-        if self.is_secondary_model_activated and self.secondary_model:
-            self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem)

         if not self.is_secondary_stem_only:
             primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
             if not isinstance(self.primary_source, np.ndarray):
-                self.primary_source = self.spec_to_wav(y_spec).T
                 if not self.model_samplerate == 44100:
                     self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T

-            self.primary_source_map = self.final_process(primary_stem_path, self.primary_source, self.secondary_source_primary, self.primary_stem, 44100)

         if not self.is_primary_stem_only:
             secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
             if not isinstance(self.secondary_source, np.ndarray):
-                self.secondary_source = self.spec_to_wav(v_spec).T
                 if not self.model_samplerate == 44100:
                     self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T

-            self.secondary_source_map = self.final_process(secondary_stem_path, self.secondary_source, self.secondary_source_secondary, self.secondary_stem, 44100)

-        clear_gpu_cache()
         secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
-
-        self.process_vocal_split_chain(secondary_sources)
-
         if self.is_secondary_model:
             return secondary_sources

@@ -1092,9 +717,6 @@ class SeperateVR(SeperateAttributes):

         bands_n = len(self.mp.param['band'])

-        audio_file = spec_utils.write_array_to_mem(self.audio_file, subtype=self.wav_type_set)
-        is_mp3 = audio_file.endswith('.mp3') if isinstance(audio_file, str) else False
-
         for d in range(bands_n, 0, -1):
             bp = self.mp.param['band'][d]

 
@@ -1104,25 +726,26 @@ class SeperateVR(SeperateAttributes):
1104
  wav_resolution = bp['res_type']
1105
 
1106
  if d == bands_n: # high-end band
1107
- X_wave[d], _ = librosa.load(audio_file, bp['sr'], False, dtype=np.float32, res_type=wav_resolution)
1108
- X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp['hl'], bp['n_fft'], self.mp, band=d, is_v51_model=self.is_vr_51_model)
1109
 
1110
- if not np.any(X_wave[d]) and is_mp3:
1111
- X_wave[d] = rerun_mp3(audio_file, bp['sr'])
1112
 
1113
  if X_wave[d].ndim == 1:
1114
  X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
1115
  else: # lower bands
1116
  X_wave[d] = librosa.resample(X_wave[d+1], self.mp.param['band'][d+1]['sr'], bp['sr'], res_type=wav_resolution)
1117
- X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp['hl'], bp['n_fft'], self.mp, band=d, is_v51_model=self.is_vr_51_model)
1118
-
 
 
1119
  if d == bands_n and self.high_end_process != 'none':
1120
  self.input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (self.mp.param['pre_filter_stop'] - self.mp.param['pre_filter_start'])
1121
  self.input_high_end = X_spec_s[d][:, bp['n_fft']//2-self.input_high_end_h:bp['n_fft']//2, :]
1122
 
1123
- X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp, is_v51_model=self.is_vr_51_model)
1124
 
1125
- del X_wave, X_spec_s, audio_file
1126
 
1127
  return X_spec
1128
 
@@ -1160,6 +783,7 @@ class SeperateVR(SeperateAttributes):
             return mask

         def postprocess(mask, X_mag, X_phase):
             is_non_accom_stem = False
             for stem in NON_ACCOM_STEMS:
                 if stem == self.primary_stem:
@@ -1174,7 +798,6 @@ class SeperateVR(SeperateAttributes):
             v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)

             return y_spec, v_spec
-
         X_mag, X_phase = spec_utils.preprocess(X_spec)
         n_frame = X_mag.shape[2]
         pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
@@ -1198,77 +821,35 @@ class SeperateVR(SeperateAttributes):
         return y_spec, v_spec

     def spec_to_wav(self, spec):
-        if self.high_end_process.startswith('mirroring') and isinstance(self.input_high_end, np.ndarray) and self.input_high_end_h:
             input_high_end_ = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp)
-            wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, input_high_end_, is_v51_model=self.is_vr_51_model)
         else:
-            wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, is_v51_model=self.is_vr_51_model)

         return wav
-
-def process_secondary_model(secondary_model: ModelData,
-                            process_data,
-                            main_model_primary_stem_4_stem=None,
-                            is_source_load=False,
-                            main_process_method=None,
-                            is_pre_proc_model=False,
-                            is_return_dual=True,
-                            main_model_primary=None):

     if not is_pre_proc_model:
         process_iteration = process_data['process_iteration']
         process_iteration()

     if secondary_model.process_method == VR_ARCH_TYPE:
-        seperator = SeperateVR(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, main_model_primary=main_model_primary)
     if secondary_model.process_method == MDX_ARCH_TYPE:
-        if secondary_model.is_mdx_c:
-            seperator = SeperateMDXC(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, is_return_dual=is_return_dual, main_model_primary=main_model_primary)
-        else:
-            seperator = SeperateMDX(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, main_model_primary=main_model_primary)
     if secondary_model.process_method == DEMUCS_ARCH_TYPE:
-        seperator = SeperateDemucs(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method, is_return_dual=is_return_dual, main_model_primary=main_model_primary)

     secondary_sources = seperator.seperate()

-    if type(secondary_sources) is dict and not is_source_load and not is_pre_proc_model:
-        return gather_sources(secondary_model.primary_model_primary_stem, secondary_stem(secondary_model.primary_model_primary_stem), secondary_sources)
     else:
         return secondary_sources

-def process_chain_model(secondary_model: ModelData,
-                        process_data,
-                        vocal_stem_path,
-                        master_vocal_source,
-                        master_inst_source=None):
-
-    process_iteration = process_data['process_iteration']
-    process_iteration()
-
-    if secondary_model.bv_model_rebalance:
-        vocal_source = spec_utils.reduce_mix_bv(master_inst_source, master_vocal_source, reduction_rate=secondary_model.bv_model_rebalance)
-    else:
-        vocal_source = master_vocal_source
-
-    vocal_stem_path = [vocal_source, os.path.splitext(os.path.basename(vocal_stem_path))[0]]
-
-    if secondary_model.process_method == VR_ARCH_TYPE:
-        seperator = SeperateVR(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
-    if secondary_model.process_method == MDX_ARCH_TYPE:
-        if secondary_model.is_mdx_c:
-            seperator = SeperateMDXC(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
-        else:
-            seperator = SeperateMDX(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
-    if secondary_model.process_method == DEMUCS_ARCH_TYPE:
-        seperator = SeperateDemucs(secondary_model, process_data, vocal_stem_path=vocal_stem_path, master_inst_source=master_inst_source, master_vocal_source=master_vocal_source)
-
-    secondary_sources = seperator.seperate()
-
-    if type(secondary_sources) is dict:
-        return secondary_sources
-    else:
-        return None
-
 def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: dict):

     source_primary = False
@@ -1282,23 +863,53 @@ def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: dict):

     return source_primary, source_secondary

-def prepare_mix(mix):
-
     audio_path = mix

     if not isinstance(mix, np.ndarray):
-        mix, sr = librosa.load(mix, mono=False, sr=44100)
     else:
         mix = mix.T

-    if isinstance(audio_path, str):
-        if not np.any(mix) and audio_path.endswith('.mp3'):
-            mix = rerun_mp3(audio_path)

     if mix.ndim == 1:
         mix = np.asfortranarray([mix,mix])

-    return mix

def rerun_mp3(audio_file, sample_rate=44100):

@@ -1323,137 +934,9 @@ def save_format(audio_path, save_format, mp3_bit_set):
1323
 
1324
  if save_format == MP3:
1325
  audio_path_mp3 = audio_path.replace(".wav", ".mp3")
1326
- try:
1327
- musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set, codec="libmp3lame")
1328
- except Exception as e:
1329
- print(e)
1330
- musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set)
1331
 
1332
  try:
1333
  os.remove(audio_path)
1334
  except Exception as e:
1335
  print(e)
1336
-
1337
- def pitch_shift(mix):
-     new_sr = 31183
-
-     # Resample audio file
-     resampled_audio = signal.resample_poly(mix, new_sr, 44100)
-
-     return resampled_audio
-
- def list_to_dictionary(lst):
-     dictionary = {item: index for index, item in enumerate(lst)}
-     return dictionary
-
- def vr_denoiser(X, device, hop_length=1024, n_fft=2048, cropsize=256, is_deverber=False, model_path=None):
-     batchsize = 4
-
-     if is_deverber:
-         nout, nout_lstm = 64, 128
-         mp = ModelParameters(os.path.join('lib_v5', 'vr_network', 'modelparams', '4band_v3.json'))
-         n_fft = mp.param['bins'] * 2
-     else:
-         mp = None
-         hop_length = 1024
-         nout, nout_lstm = 16, 128
-
-     model = nets_new.CascadedNet(n_fft, nout=nout, nout_lstm=nout_lstm)
-     model.load_state_dict(torch.load(model_path, map_location=cpu))
-     model.to(device)
-
-     if mp is None:
-         X_spec = spec_utils.wave_to_spectrogram_old(X, hop_length, n_fft)
-     else:
-         X_spec = loading_mix(X.T, mp)
-
-     # PreProcess
-     X_mag = np.abs(X_spec)
-     X_phase = np.angle(X_spec)
-
-     # Sep
-     n_frame = X_mag.shape[2]
-     pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, cropsize, model.offset)
-     X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
-     X_mag_pad /= X_mag_pad.max()
-
-     X_dataset = []
-     patches = (X_mag_pad.shape[2] - 2 * model.offset) // roi_size
-     for i in range(patches):
-         start = i * roi_size
-         X_mag_crop = X_mag_pad[:, :, start:start + cropsize]
-         X_dataset.append(X_mag_crop)
-
-     X_dataset = np.asarray(X_dataset)
-
-     model.eval()
-
-     with torch.no_grad():
-         mask = []
-         # To reduce the overhead, dataloader is not used.
-         for i in range(0, patches, batchsize):
-             X_batch = X_dataset[i: i + batchsize]
-             X_batch = torch.from_numpy(X_batch).to(device)
-
-             pred = model.predict_mask(X_batch)
-
-             pred = pred.detach().cpu().numpy()
-             pred = np.concatenate(pred, axis=2)
-             mask.append(pred)
-
-         mask = np.concatenate(mask, axis=2)
-
-     mask = mask[:, :, :n_frame]
-
-     # Post Proc
-     if is_deverber:
-         v_spec = mask * X_mag * np.exp(1.j * X_phase)
-         y_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
-     else:
-         v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
-
-     if mp is None:
-         wave = spec_utils.spectrogram_to_wave_old(v_spec, hop_length=1024)
-     else:
-         wave = spec_utils.cmb_spectrogram_to_wave(v_spec, mp, is_v51_model=True).T
-
-     wave = spec_utils.match_array_shapes(wave, X)
-
-     if is_deverber:
-         wave_2 = spec_utils.cmb_spectrogram_to_wave(y_spec, mp, is_v51_model=True).T
-         wave_2 = spec_utils.match_array_shapes(wave_2, X)
-         return wave, wave_2
-     else:
-         return wave
-
- def loading_mix(X, mp):
-
-     X_wave, X_spec_s = {}, {}
-
-     bands_n = len(mp.param['band'])
-
-     for d in range(bands_n, 0, -1):
-         bp = mp.param['band'][d]
-
-         if OPERATING_SYSTEM == 'Darwin':
-             wav_resolution = 'polyphase' if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else bp['res_type']
-         else:
-             wav_resolution = 'polyphase'  # bp['res_type']
-
-         if d == bands_n:  # high-end band
-             X_wave[d] = X
-         else:  # lower bands
-             X_wave[d] = librosa.resample(X_wave[d+1], mp.param['band'][d+1]['sr'], bp['sr'], res_type=wav_resolution)
-
-         X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp['hl'], bp['n_fft'], mp, band=d, is_v51_model=True)
-
-         # if d == bands_n and is_high_end_process:
-         #     input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (mp.param['pre_filter_stop'] - mp.param['pre_filter_start'])
-         #     input_high_end = X_spec_s[d][:, bp['n_fft']//2-input_high_end_h:bp['n_fft']//2, :]
-
-     X_spec = spec_utils.combine_spectrograms(X_spec_s, mp)
-
-     del X_wave, X_spec_s
-
-     return X_spec
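The removed vr_denoiser predicts a soft magnitude mask and reapplies the mixture phase; the kept and discarded components always sum back to the input. The masking step in isolation, with a synthetic spectrogram and a random mask standing in for model.predict_mask:

import numpy as np

X_spec = np.random.randn(2, 1025, 256) + 1j * np.random.randn(2, 1025, 256)  # (channels, bins, frames)
X_mag, X_phase = np.abs(X_spec), np.angle(X_spec)
mask = np.random.rand(*X_mag.shape)                     # stand-in for model.predict_mask, values in [0, 1]

kept_spec = mask * X_mag * np.exp(1.j * X_phase)        # e.g. the reverb estimate in deverb mode
residual_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)

assert np.allclose(kept_spec + residual_spec, X_spec)   # masking is a lossless split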
 
  from demucs.pretrained import get_model as _gm
  from demucs.utils import apply_model_v1
  from demucs.utils import apply_model_v2
  from lib_v5 import spec_utils
  from lib_v5.vr_network import nets
  from lib_v5.vr_network import nets_new
+ #from lib_v5.vr_network.model_param_init import ModelParameters
  from pathlib import Path
  from gui_data.constants import *
  from gui_data.error_handling import *
  import audioread
  import gzip
  import librosa

  import warnings
  import pydub
  import soundfile as sf
+ import traceback
  import lib_v5.mdxnet as MdxnetSet
+
  if TYPE_CHECKING:
      from UVR import ModelData
  warnings.filterwarnings("ignore")
  cpu = torch.device('cpu')

  class SeperateAttributes:
+     def __init__(self, model_data: ModelData, process_data: dict, main_model_primary_stem_4_stem=None, main_process_method=None):

          self.list_all_models: list
          self.process_data = process_data
          self.progress_value = 0
          self.set_progress_bar = process_data['set_progress_bar']
          self.write_to_console = process_data['write_to_console']
+         self.audio_file = process_data['audio_file']
+         self.audio_file_base = process_data['audio_file_base']
          self.export_path = process_data['export_path']
          self.cached_source_callback = process_data['cached_source_callback']
          self.cached_model_source_holder = process_data['cached_model_source_holder']
          self.is_4_stem_ensemble = process_data['is_4_stem_ensemble']
          self.list_all_models = process_data['list_all_models']
          self.process_iteration = process_data['process_iteration']
          self.mixer_path = model_data.mixer_path
          self.model_samplerate = model_data.model_samplerate
          self.model_capacity = model_data.model_capacity
          self.is_ensemble_mode = model_data.is_ensemble_mode
          self.secondary_model = model_data.secondary_model #
          self.primary_model_primary_stem = model_data.primary_model_primary_stem
          self.primary_stem = model_data.primary_stem #
          self.secondary_stem = model_data.secondary_stem #
          self.is_invert_spec = model_data.is_invert_spec #
          self.is_mixer_mode = model_data.is_mixer_mode #
          self.secondary_model_scale = model_data.secondary_model_scale #
          self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix #
          self.secondary_source = None
          self.secondary_source_primary = None
          self.secondary_source_secondary = None

+         if not model_data.process_method == DEMUCS_ARCH_TYPE:
+             if process_data['is_ensemble_master'] and not self.is_4_stem_ensemble:
+                 if not model_data.ensemble_primary_stem == self.primary_stem:
+                     self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only
+
+             if self.is_secondary_model and not process_data['is_ensemble_master']:
+                 if not self.primary_model_primary_stem == self.primary_stem and not main_model_primary_stem_4_stem:
+                     self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only
+
+             if main_model_primary_stem_4_stem:
+                 self.is_primary_stem_only = True if main_model_primary_stem_4_stem == self.primary_stem else False
+                 self.is_secondary_stem_only = True if not main_model_primary_stem_4_stem == self.primary_stem else False
+
+             if self.is_pre_proc_model:
+                 self.is_primary_stem_only = True if self.primary_stem == INST_STEM else False
+                 self.is_secondary_stem_only = True if self.secondary_stem == INST_STEM else False
          if model_data.process_method == MDX_ARCH_TYPE:
              self.is_mdx_ckpt = model_data.is_mdx_ckpt
              self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename)
+             self.is_denoise = model_data.is_denoise
              self.mdx_batch_size = model_data.mdx_batch_size
              self.compensate = model_data.compensate
+             self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
              self.n_fft = model_data.mdx_n_fft_scale_set
              self.chunks = model_data.chunks
              self.margin = model_data.margin
              self.adjust = 1
              self.dim_c = 4
              self.hop = 1024
+
+             if self.is_gpu_conversion >= 0 and torch.cuda.is_available():
+                 self.device, self.run_type = torch.device('cuda:0'), ['CUDAExecutionProvider']
+             else:
+                 self.device, self.run_type = torch.device('cpu'), ['CPUExecutionProvider']
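The MDX path keys both the torch device and the ONNX Runtime provider list off the same GPU check, so tensors and the inference session always land on matching hardware. That pairing in isolation (a sketch, assuming onnxruntime is installed):

import torch

def pick_execution(is_gpu_conversion: int):
    # GPU requested and actually present: CUDA for torch, CUDA EP for ONNX Runtime.
    if is_gpu_conversion >= 0 and torch.cuda.is_available():
        return torch.device('cuda:0'), ['CUDAExecutionProvider']
    return torch.device('cpu'), ['CPUExecutionProvider']

device, run_type = pick_execution(0)
print(device, run_type)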
          if model_data.process_method == DEMUCS_ARCH_TYPE:
              self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None
              self.secondary_model_4_stem = model_data.secondary_model_4_stem
              self.secondary_model_4_stem_scale = model_data.secondary_model_4_stem_scale
+             self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem
+             self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem
              self.is_chunk_demucs = model_data.is_chunk_demucs
              self.segment = model_data.segment
              self.demucs_version = model_data.demucs_version

              self.is_demucs_combine_stems = model_data.is_demucs_combine_stems
              self.demucs_stem_count = model_data.demucs_stem_count
              self.pre_proc_model = model_data.pre_proc_model

              if self.is_secondary_model and not process_data['is_ensemble_master']:
                  if not self.demucs_stem_count == 2 and model_data.primary_model_primary_stem == INST_STEM:
                      self.primary_stem = VOCAL_STEM
                      self.secondary_stem = INST_STEM
                  else:
                      self.primary_stem = model_data.primary_model_primary_stem
+                     self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
+
+             if self.is_chunk_demucs:
+                 self.chunks_demucs = model_data.chunks_demucs
+                 self.margin_demucs = model_data.margin_demucs
+             else:
+                 self.chunks_demucs = 0
+                 self.margin_demucs = 44100
+
              self.shifts = model_data.shifts
              self.is_split_mode = model_data.is_split_mode if not self.demucs_version == DEMUCS_V4 else True
+             self.overlap = model_data.overlap
              self.primary_model_name, self.primary_sources = self.cached_source_callback(DEMUCS_ARCH_TYPE, model_name=self.model_basename)

          if model_data.process_method == VR_ARCH_TYPE:
              self.primary_model_name, self.primary_sources = self.cached_source_callback(VR_ARCH_TYPE, model_name=self.model_basename)
              self.mp = model_data.vr_model_param
              self.high_end_process = model_data.is_high_end_process

              self.batch_size = model_data.batch_size
              self.window_size = model_data.window_size
              self.input_high_end_h = None
              self.post_process_threshold = model_data.post_process_threshold
              self.aggressiveness = {'value': model_data.aggression_setting,
                                     'split_bin': self.mp.param['band'][1]['crop_stop'],
                                     'aggr_correction': self.mp.param.get('aggr_correction')}

      def start_inference_console_write(self):

+         if self.is_secondary_model and not self.is_pre_proc_model:
              self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename))

          if self.is_pre_proc_model:
              self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename))

      def running_inference_console_write(self, is_no_write=False):

          self.write_to_console(DONE, base_text='') if not is_no_write else None
          self.set_progress_bar(0.05) if not is_no_write else None

+         if self.is_secondary_model and not self.is_pre_proc_model:
              self.write_to_console(INFERENCE_STEP_1_SEC)
          elif self.is_pre_proc_model:
              self.write_to_console(INFERENCE_STEP_1_PRE)
          else:
              self.write_to_console(INFERENCE_STEP_1)

          self.set_progress_bar(0.1, (0.8/length*self.progress_value))

+     def load_cached_sources(self, is_4_stem_demucs=False):

          if self.is_secondary_model and not self.is_pre_proc_model:
              self.write_to_console(INFERENCE_STEP_2_SEC_CACHED_MODOEL(self.process_method, self.model_basename))
          elif self.is_pre_proc_model:
              self.write_to_console(INFERENCE_STEP_2_PRE_CACHED_MODOEL(self.process_method, self.model_basename))
          else:
+             self.write_to_console(INFERENCE_STEP_2_PRIMARY_CACHED)
+
+         if not is_4_stem_demucs:
+             primary_stem, secondary_stem = gather_sources(self.primary_stem, self.secondary_stem, self.primary_sources)
+
+             return primary_stem, secondary_stem

      def cache_source(self, secondary_sources):

          if self.process_method == DEMUCS_ARCH_TYPE:
              self.cached_model_source_holder(DEMUCS_ARCH_TYPE, secondary_sources, self.model_basename)
+
+     def write_audio(self, stem_path, stem_source, samplerate, secondary_model_source=None, model_scale=None):
+
          if not self.is_secondary_model:
+             if self.is_secondary_model_activated:
+                 if isinstance(secondary_model_source, np.ndarray):
+                     secondary_model_scale = model_scale if model_scale else self.secondary_model_scale
+                     stem_source = spec_utils.average_dual_sources(stem_source, secondary_model_source, secondary_model_scale)

+             sf.write(stem_path, stem_source, samplerate, subtype=self.wav_type_set)
+             save_format(stem_path, self.save_format, self.mp3_bit_set) if not self.is_ensemble_mode else None

              self.write_to_console(DONE, base_text='')
+             self.set_progress_bar(0.95)

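write_audio blends a secondary model's stem into the primary one through spec_utils.average_dual_sources before writing. A minimal equivalent of such a weighted average, assuming two stems shaped (samples, channels); the exact weighting convention of the real helper may differ:

import numpy as np

def blend_stems(primary: np.ndarray, secondary: np.ndarray, scale: float) -> np.ndarray:
    # Trim to the shorter stem, then mix; scale weights the primary model's output.
    n = min(primary.shape[0], secondary.shape[0])
    return primary[:n] * scale + secondary[:n] * (1 - scale)

a = np.random.randn(44100, 2).astype(np.float32)
b = np.random.randn(44100, 2).astype(np.float32)
print(blend_stems(a, b, 0.9).shape)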
+     def run_mixer(self, mix, sources):
+         try:
+             if self.is_mixer_mode and len(sources) == 4:
+                 mixer = MdxnetSet.Mixer(self.device, self.mixer_path).eval()
+                 with torch.no_grad():
+                     mix = torch.tensor(mix, dtype=torch.float32)
+                     sources_ = torch.tensor(sources).detach()
+                     x = torch.cat([sources_, mix.unsqueeze(0)], 0)
+                     sources_ = mixer(x)
+                     final_source = np.array(sources_)
+             else:
+                 final_source = sources
+         except Exception as e:
+             error_name = f'{type(e).__name__}'
+             traceback_text = ''.join(traceback.format_tb(e.__traceback__))
+             message = f'{error_name}: "{e}"\n{traceback_text}'
+             print('Mixer Failed: ', message)
+             final_source = sources
+
+         return final_source

  class SeperateMDX(SeperateAttributes):

      def seperate(self):
          samplerate = 44100
+
+         if self.primary_model_name == self.model_basename and self.primary_sources:
+             self.primary_source, self.secondary_source = self.load_cached_sources()
          else:
              self.start_inference_console_write()

                  separator = MdxnetSet.ConvTDFNet(**model_params)
                  self.model_run = separator.load_from_checkpoint(self.model_path).to(self.device).eval()
              else:
+                 ort_ = ort.InferenceSession(self.model_path, providers=self.run_type)
+                 self.model_run = lambda spek: ort_.run(None, {'input': spek.cpu().numpy()})[0]

+             self.initialize_model_settings()
              self.running_inference_console_write()
+             mdx_net_cut = True if self.primary_stem in MDX_NET_FREQ_CUT else False
+             mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks, self.margin, mdx_net_cut=mdx_net_cut)
+             source = self.demix_base(mix, is_ckpt=self.is_mdx_ckpt)[0]

              self.write_to_console(DONE, base_text='')

+         if self.is_secondary_model_activated:
+             if self.secondary_model:
+                 self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)

+         if not self.is_secondary_stem_only:
+             self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
+             primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
+             if not isinstance(self.primary_source, np.ndarray):
+                 self.primary_source = spec_utils.normalize(source, self.is_normalization).T
+             self.primary_source_map = {self.primary_stem: self.primary_source}
+             self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary)
+
          if not self.is_primary_stem_only:
+             self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
              secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
              if not isinstance(self.secondary_source, np.ndarray):
+                 raw_mix = self.demix_base(raw_mix, is_match_mix=True)[0] if mdx_net_cut else raw_mix
+                 self.secondary_source, raw_mix = spec_utils.normalize_two_stem(source*self.compensate, raw_mix, self.is_normalization)

+                 if self.is_invert_spec:
+                     self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source)
+                 else:
+                     self.secondary_source = (-self.secondary_source.T+raw_mix.T)

+             self.secondary_source_map = {self.secondary_stem: self.secondary_source}
+             self.write_audio(secondary_stem_path, self.secondary_source, samplerate, self.secondary_source_secondary)

+         torch.cuda.empty_cache()
          secondary_sources = {**self.primary_source_map, **self.secondary_source_map}

+         self.cache_source(secondary_sources)
+
+         if self.is_secondary_model:
              return secondary_sources

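The ONNX branch above wraps an InferenceSession in a plain callable so checkpoint and ONNX models share one code path. A sketch of that wrapper, assuming the model's input tensor is named 'input' as in the lambda above (the file name and shape below are placeholders):

import numpy as np
import onnxruntime as ort

def make_model_run(model_path: str, providers=('CPUExecutionProvider',)):
    session = ort.InferenceSession(model_path, providers=list(providers))
    # run(None, feeds) returns every output; the separator only uses the first.
    return lambda spek: session.run(None, {'input': spek})[0]

# model_run = make_model_run('some_mdx_model.onnx')
# spec_pred = model_run(np.zeros((1, 4, 2048, 256), dtype=np.float32))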
      def initialize_model_settings(self):
          self.n_bins = self.n_fft//2+1
          self.trim = self.n_fft//2
+         self.chunk_size = self.hop * (self.dim_t-1)
+         self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device)
+         self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins-self.dim_f, self.dim_t]).to(self.device)
          self.gen_size = self.chunk_size-2*self.trim

+     def initialize_mix(self, mix, is_ckpt=False):
+         if is_ckpt:
+             pad = self.gen_size + self.trim - ((mix.shape[-1]) % self.gen_size)
+             mixture = np.concatenate((np.zeros((2, self.trim), dtype='float32'), mix, np.zeros((2, pad), dtype='float32')), 1)
+             num_chunks = mixture.shape[-1] // self.gen_size
+             mix_waves = [mixture[:, i * self.gen_size: i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
          else:
+             mix_waves = []
+             n_sample = mix.shape[1]
+             pad = self.gen_size - n_sample%self.gen_size
+             mix_p = np.concatenate((np.zeros((2,self.trim)), mix, np.zeros((2,pad)), np.zeros((2,self.trim))), 1)
+             i = 0
+             while i < n_sample + pad:
+                 waves = np.array(mix_p[:, i:i+self.chunk_size])
+                 mix_waves.append(waves)
+                 i += self.gen_size
+
+             mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)

+         return mix_waves, pad
+
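initialize_mix pads the mix so hop-aligned windows tile it exactly, with trim samples of silence on each end; consecutive windows overlap by 2*trim. The non-checkpoint branch in isolation:

import numpy as np

def window_mix(mix: np.ndarray, gen_size: int, trim: int) -> list:
    chunk_size = gen_size + 2 * trim
    n_sample = mix.shape[1]
    pad = gen_size - n_sample % gen_size            # fill out the last partial window
    mix_p = np.concatenate((np.zeros((2, trim)), mix, np.zeros((2, pad)), np.zeros((2, trim))), axis=1)
    return [mix_p[:, i:i + chunk_size] for i in range(0, n_sample + pad, gen_size)]

waves = window_mix(np.random.randn(2, 100000), gen_size=254976, trim=3072)
print(len(waves), waves[0].shape)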
+     def demix_base(self, mix, is_ckpt=False, is_match_mix=False):
+         chunked_sources = []
+         for slice in mix:
+             sources = []
+             tar_waves_ = []
+             mix_p = mix[slice]
+             mix_waves, pad = self.initialize_mix(mix_p, is_ckpt=is_ckpt)
+             mix_waves = mix_waves.split(self.mdx_batch_size)
+             pad = mix_p.shape[-1] if is_ckpt else -pad
              with torch.no_grad():
                  for mix_wave in mix_waves:
+                     self.running_inference_progress_bar(len(mix)*len(mix_waves), is_match_mix=is_match_mix)
+                     tar_waves = self.run_model(mix_wave, is_ckpt=is_ckpt, is_match_mix=is_match_mix)
+                     tar_waves_.append(tar_waves)
+             tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim:-self.trim] if is_ckpt else tar_waves_
+             tar_waves = np.concatenate(tar_waves_, axis=-1)[:, :pad]
+             start = 0 if slice == 0 else self.margin
+             end = None if slice == list(mix.keys())[::-1][0] or self.margin == 0 else -self.margin
+             sources.append(tar_waves[:,start:end]*(1/self.adjust))
+             chunked_sources.append(sources)
+         sources = np.concatenate(chunked_sources, axis=-1)

+         return sources

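demix_base runs each overlapping chunk through the model, then trims the margin from interior chunk edges so the stitched result has no seams. The stitching rule in isolation, assuming outputs keyed by chunk start sample (as prepare_mix below produces):

import numpy as np

def stitch(outputs: dict, margin: int) -> np.ndarray:
    keys = list(outputs.keys())
    pieces = []
    for key, wave in outputs.items():
        start = 0 if key == keys[0] else margin          # keep the leading edge only once
        end = None if key == keys[-1] or margin == 0 else -margin
        pieces.append(wave[:, start:end])
    return np.concatenate(pieces, axis=-1)

chunks = {0: np.random.randn(2, 2000), 1500: np.random.randn(2, 3000), 4500: np.random.randn(2, 1500)}
print(stitch(chunks, margin=500).shape)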
+     def run_model(self, mix, is_ckpt=False, is_match_mix=False):

          spek = self.stft(mix.to(self.device))*self.adjust
          spek[:, :, :3, :] *= 0

          else:
              spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5 if self.is_denoise else self.model_run(spek)

+         if is_ckpt:
+             return self.istft(spec_pred).cpu().detach().numpy()
+         else:
+             return self.istft(torch.tensor(spec_pred).to(self.device)).to(cpu)[:,:,self.trim:-self.trim].transpose(0,1).reshape(2, -1).numpy()

+     def stft(self, x):
+         x = x.reshape([-1, self.chunk_size])
+         x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
+         x = torch.view_as_real(x)
+         x = x.permute([0,3,1,2])
+         x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t])
+         return x[:,:,:self.dim_f]

+     def istft(self, x, freq_pad=None):
+         freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad
+         x = torch.cat([x, freq_pad], -2)
+         x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
+         x = x.permute([0,2,3,1])
+         x = x.contiguous()
+         x = torch.view_as_complex(x)
+         x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
+         return x.reshape([-1,2,self.chunk_size])
+
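The stft/istft pair above packs stereo real and imaginary planes into dim_c=4 channels around an ordinary torch STFT, and run_model's -model_run(-spek) averaging suppresses artifacts that do not flip sign with the input. The underlying transform round-trips cleanly; a standalone check with illustrative sizes in the range MDX models use (not taken from any specific model):

import torch

n_fft, hop, dim_t = 6144, 1024, 256
chunk_size = hop * (dim_t - 1)
window = torch.hann_window(window_length=n_fft, periodic=False)

x = torch.randn(2, chunk_size)                      # one stereo chunk
spec = torch.stft(x, n_fft=n_fft, hop_length=hop, window=window,
                  center=True, return_complex=True)
y = torch.istft(spec, n_fft=n_fft, hop_length=hop, window=window,
                center=True, length=chunk_size)
print(torch.allclose(x, y, atol=1e-4))              # near-lossless round trip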
+ class SeperateDemucs(SeperateAttributes):

      def seperate(self):

          samplerate = 44100
          source = None
          model_scale = None
          stem_source = None
          stem_source_secondary = None
          inst_mix = None
+         inst_raw_mix = None
+         raw_mix = None
          inst_source = None
          is_no_write = False
          is_no_piano_guitar = False
+
+         if self.primary_model_name == self.model_basename and type(self.primary_sources) is dict and not self.pre_proc_model:
+             self.primary_source, self.secondary_source = self.load_cached_sources()
+         elif self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and not self.pre_proc_model:
              source = self.primary_sources
+             self.load_cached_sources(is_4_stem_demucs=True)
          else:
              self.start_inference_console_write()

+             if self.is_gpu_conversion >= 0:
+                 self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+             else:
+                 self.device = torch.device('cpu')
+
              if self.demucs_version == DEMUCS_V1:
                  if str(self.model_path).endswith(".gz"):
                      self.model_path = gzip.open(self.model_path, "rb")

                  is_no_write = True
                  self.write_to_console(DONE, base_text='')
                  mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True)
+                 inst_mix, inst_raw_mix, inst_samplerate = prepare_mix(mix_no_voc[INST_STEM], self.chunks_demucs, self.margin_demucs)
                  self.process_iteration()
                  self.running_inference_console_write(is_no_write=is_no_write)
                  inst_source = self.demix_demucs(inst_mix)
+                 inst_source = self.run_mixer(inst_raw_mix, inst_source)
                  self.process_iteration()

              self.running_inference_console_write(is_no_write=is_no_write) if not self.pre_proc_model else None
+             mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs)

              if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model:
                  source = self.primary_sources
              else:
                  source = self.demix_demucs(mix)
+                 source = self.run_mixer(raw_mix, source)

              self.write_to_console(DONE, base_text='')

              del self.demucs
+             torch.cuda.empty_cache()

          if isinstance(inst_source, np.ndarray):
              source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]])

              source = inst_source

          if isinstance(source, np.ndarray):
              if len(source) == 2:
                  self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
              else:

                      other_source += i
                  source_reshape = spec_utils.reshape_sources(source[self.demucs_source_map[OTHER_STEM]], other_source)
                  source[self.demucs_source_map[OTHER_STEM]] = source_reshape
+
+         if (self.demucs_stems == ALL_STEMS and not self.process_data['is_ensemble_master']) or self.is_4_stem_ensemble:
+             self.cache_source(source)
+
          for stem_name, stem_value in self.demucs_source_map.items():
              if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4:
                  if self.secondary_model_4_stem[stem_value]:
                      model_scale = self.secondary_model_4_stem_scale[stem_value]
+                     stem_source_secondary = process_secondary_model(self.secondary_model_4_stem[stem_value], self.process_data, main_model_primary_stem_4_stem=stem_name, is_4_stem_demucs=True)
                      if isinstance(stem_source_secondary, np.ndarray):
+                         stem_source_secondary = stem_source_secondary[1 if self.secondary_model_4_stem[stem_value].demucs_stem_count == 2 else stem_value]
+                         stem_source_secondary = spec_utils.normalize(stem_source_secondary, self.is_normalization).T
                      elif type(stem_source_secondary) is dict:
                          stem_source_secondary = stem_source_secondary[stem_name]

              stem_source_secondary = None if stem_value >= 4 else stem_source_secondary
+             self.write_to_console(f'{SAVING_STEM[0]}{stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None
              stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem_name}).wav')
+             stem_source = spec_utils.normalize(source[stem_value], self.is_normalization).T
+             self.write_audio(stem_path, stem_source, samplerate, secondary_model_source=stem_source_secondary, model_scale=model_scale)
+
          if self.is_secondary_model:
              return source
          else:
+             if self.is_secondary_model_activated:
+                 if self.secondary_model:
                      self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)
+
+             if not self.is_secondary_stem_only:
+                 self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
+                 primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
+                 if not isinstance(self.primary_source, np.ndarray):
+                     self.primary_source = spec_utils.normalize(source[self.demucs_source_map[self.primary_stem]], self.is_normalization).T
+                 self.primary_source_map = {self.primary_stem: self.primary_source}
+                 self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary)
+
              if not self.is_primary_stem_only:
                  def secondary_save(sec_stem_name, source, raw_mixture=None, is_inst_mixture=False):
                      secondary_source = self.secondary_source if not is_inst_mixture else None
+                     self.write_to_console(f'{SAVING_STEM[0]}{sec_stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None
                      secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({sec_stem_name}).wav')
                      secondary_source_secondary = None

                          secondary_source = np.zeros_like(source[0])
                          for i in source:
                              secondary_source += i
+                         secondary_source = spec_utils.normalize(secondary_source, self.is_normalization).T
                      else:
                          if not isinstance(raw_mixture, np.ndarray):
+                             raw_mixture = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs, is_missing_mix=True)

+                         secondary_source, raw_mixture = spec_utils.normalize_two_stem(source[self.demucs_source_map[self.primary_stem]], raw_mixture, self.is_normalization)

                          if self.is_invert_spec:
                              secondary_source = spec_utils.invert_stem(raw_mixture, secondary_source)

                      if not is_inst_mixture:
                          self.secondary_source = secondary_source
                          secondary_source_secondary = self.secondary_source_secondary
                          self.secondary_source_map = {self.secondary_stem: self.secondary_source}

+                     self.write_audio(secondary_stem_path, secondary_source, samplerate, secondary_source_secondary)

+                 secondary_save(self.secondary_stem, source, raw_mixture=raw_mix)

                  if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble:
+                     secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture=inst_raw_mix, is_inst_mixture=True)

              secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
+
+             self.cache_source(secondary_sources)

          if self.is_secondary_model:
              return secondary_sources

      def demix_demucs(self, mix):
          processed = {}
+
+         set_progress_bar = None if self.is_chunk_demucs else self.set_progress_bar
+
+         for nmix in mix:
+             self.progress_value += 1
+             self.set_progress_bar(0.1, (0.8/len(mix)*self.progress_value)) if self.is_chunk_demucs else None
+             cmix = mix[nmix]
+             cmix = torch.tensor(cmix, dtype=torch.float32)
+             ref = cmix.mean(0)
+             cmix = (cmix - ref.mean()) / ref.std()
+             mix_infer = cmix
+
+             with torch.no_grad():
+                 if self.demucs_version == DEMUCS_V1:
+                     sources = apply_model_v1(self.demucs,
+                                              mix_infer.to(self.device),
+                                              self.shifts,
+                                              self.is_split_mode,
+                                              set_progress_bar=set_progress_bar)
+                 elif self.demucs_version == DEMUCS_V2:
+                     sources = apply_model_v2(self.demucs,
+                                              mix_infer.to(self.device),
+                                              self.shifts,
+                                              self.is_split_mode,
+                                              self.overlap,
+                                              set_progress_bar=set_progress_bar)
+                 else:
+                     sources = apply_model(self.demucs,
+                                           mix_infer[None],
                                            self.shifts,
                                            self.is_split_mode,
                                            self.overlap,
+                                           static_shifts=1 if self.shifts == 0 else self.shifts,
+                                           set_progress_bar=set_progress_bar,
+                                           device=self.device)[0]
+
+             sources = (sources * ref.std() + ref.mean()).cpu().numpy()
+             sources[[0,1]] = sources[[1,0]]
+             start = 0 if nmix == 0 else self.margin_demucs
+             end = None if nmix == list(mix.keys())[::-1][0] else -self.margin_demucs
+             if self.margin_demucs == 0:
+                 end = None
+             processed[nmix] = sources[:,:,start:end].copy()
+         sources = list(processed.values())
          sources = np.concatenate(sources, axis=-1)

          return sources

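demix_demucs standardizes each chunk against its mono reference before inference and inverts the transform afterwards, so the model always sees zero-mean, unit-variance input. The pair in isolation:

import torch

cmix = torch.randn(2, 44100)               # one stereo chunk
ref = cmix.mean(0)                         # mono reference
normed = (cmix - ref.mean()) / ref.std()

# ... run the separation model on `normed` ...

restored = normed * ref.std() + ref.mean()
print(torch.allclose(cmix, restored, atol=1e-5))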
  class SeperateVR(SeperateAttributes):

      def seperate(self):
+         if self.primary_model_name == self.model_basename and self.primary_sources:
+             self.primary_source, self.secondary_source = self.load_cached_sources()
          else:
              self.start_inference_console_write()
+             if self.is_gpu_conversion >= 0:
+                 if OPERATING_SYSTEM == 'Darwin':
+                     device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
+                 else:
+                     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+             else:
+                 device = torch.device('cpu')

          nn_arch_sizes = [
              31191, # default

          nn_arch_size = min(nn_arch_sizes, key=lambda x:abs(x-model_size))

          if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
+             self.model_run = nets_new.CascadedNet(self.mp.param['bins'] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
          else:
              self.model_run = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_arch_size)

          self.running_inference_console_write()

          y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness)
          self.write_to_console(DONE, base_text='')

+         if self.is_secondary_model_activated:
+             if self.secondary_model:
+                 self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method)

          if not self.is_secondary_stem_only:
+             self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
              primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav')
              if not isinstance(self.primary_source, np.ndarray):
+                 self.primary_source = spec_utils.normalize(self.spec_to_wav(y_spec), self.is_normalization).T
                  if not self.model_samplerate == 44100:
                      self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T

+             self.primary_source_map = {self.primary_stem: self.primary_source}
+
+             self.write_audio(primary_stem_path, self.primary_source, 44100, self.secondary_source_primary)

          if not self.is_primary_stem_only:
+             self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None
              secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav')
              if not isinstance(self.secondary_source, np.ndarray):
+                 self.secondary_source = spec_utils.normalize(self.spec_to_wav(v_spec), self.is_normalization).T
                  if not self.model_samplerate == 44100:
                      self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T

+             self.secondary_source_map = {self.secondary_stem: self.secondary_source}
+
+             self.write_audio(secondary_stem_path, self.secondary_source, 44100, self.secondary_source_secondary)

+         torch.cuda.empty_cache()
          secondary_sources = {**self.primary_source_map, **self.secondary_source_map}
+         self.cache_source(secondary_sources)

          if self.is_secondary_model:
              return secondary_sources

          bands_n = len(self.mp.param['band'])

          for d in range(bands_n, 0, -1):
              bp = self.mp.param['band'][d]

                  wav_resolution = bp['res_type']

              if d == bands_n: # high-end band
+                 X_wave[d], _ = librosa.load(self.audio_file, bp['sr'], False, dtype=np.float32, res_type=wav_resolution)

+                 if not np.any(X_wave[d]) and self.audio_file.endswith('.mp3'):
+                     X_wave[d] = rerun_mp3(self.audio_file, bp['sr'])

                  if X_wave[d].ndim == 1:
                      X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
              else: # lower bands
                  X_wave[d] = librosa.resample(X_wave[d+1], self.mp.param['band'][d+1]['sr'], bp['sr'], res_type=wav_resolution)
+
+             X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], self.mp.param['mid_side'],
+                                                             self.mp.param['mid_side_b2'], self.mp.param['reverse'])
+
              if d == bands_n and self.high_end_process != 'none':
                  self.input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (self.mp.param['pre_filter_stop'] - self.mp.param['pre_filter_start'])
                  self.input_high_end = X_spec_s[d][:, bp['n_fft']//2-self.input_high_end_h:bp['n_fft']//2, :]

+         X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp)

+         del X_wave, X_spec_s

          return X_spec

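loading_mix derives each lower band by resampling the band above it rather than the 44.1 kHz source, so every band only steps down once. The cascade in isolation, with made-up per-band rates (real rates come from the model's modelparams JSON):

import numpy as np
import librosa

band_sr = {4: 44100, 3: 22050, 2: 11025, 1: 7350}   # illustrative band layout
X_wave = {4: np.random.randn(2, 44100).astype(np.float32)}

for d in range(3, 0, -1):
    X_wave[d] = librosa.resample(X_wave[d + 1], orig_sr=band_sr[d + 1],
                                 target_sr=band_sr[d], res_type='polyphase')

print({d: w.shape for d, w in X_wave.items()})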
              return mask

          def postprocess(mask, X_mag, X_phase):
+
              is_non_accom_stem = False
              for stem in NON_ACCOM_STEMS:
                  if stem == self.primary_stem:

              v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)

              return y_spec, v_spec

          X_mag, X_phase = spec_utils.preprocess(X_spec)
          n_frame = X_mag.shape[2]
          pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)

          return y_spec, v_spec

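inference_vr pads the spectrogram so windows of window_size tile the frames while accounting for the model's context offset. A common formulation of such a helper; the in-repo spec_utils.make_padding may differ in detail:

def make_padding(n_frame: int, cropsize: int, offset: int):
    left = offset
    roi_size = cropsize - offset * 2      # effective stride after trimming model context
    if roi_size == 0:
        roi_size = cropsize
    right = roi_size - (n_frame % roi_size) + left
    return left, right, roi_size

print(make_padding(n_frame=430, cropsize=512, offset=128))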
      def spec_to_wav(self, spec):

+         if self.high_end_process.startswith('mirroring'):
              input_high_end_ = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp)
+             wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, input_high_end_)
          else:
+             wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp)

          return wav
+
+ def process_secondary_model(secondary_model: ModelData, process_data, main_model_primary_stem_4_stem=None, is_4_stem_demucs=False, main_process_method=None, is_pre_proc_model=False):

      if not is_pre_proc_model:
          process_iteration = process_data['process_iteration']
          process_iteration()

      if secondary_model.process_method == VR_ARCH_TYPE:
+         seperator = SeperateVR(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method)
      if secondary_model.process_method == MDX_ARCH_TYPE:
+         seperator = SeperateMDX(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method)
      if secondary_model.process_method == DEMUCS_ARCH_TYPE:
+         seperator = SeperateDemucs(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method)

      secondary_sources = seperator.seperate()

+     if type(secondary_sources) is dict and not is_4_stem_demucs and not is_pre_proc_model:
+         return gather_sources(secondary_model.primary_model_primary_stem, STEM_PAIR_MAPPER[secondary_model.primary_model_primary_stem], secondary_sources)
      else:
          return secondary_sources

  def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: dict):

      source_primary = False

      return source_primary, source_secondary

+ def prepare_mix(mix, chunk_set, margin_set, mdx_net_cut=False, is_missing_mix=False):
+
      audio_path = mix
+     samplerate = 44100

      if not isinstance(mix, np.ndarray):
+         mix, samplerate = librosa.load(mix, mono=False, sr=44100)
      else:
          mix = mix.T

+     if not np.any(mix) and audio_path.endswith('.mp3'):
+         mix = rerun_mp3(audio_path)

      if mix.ndim == 1:
          mix = np.asfortranarray([mix,mix])

+     def get_segmented_mix(chunk_set=chunk_set):
+         segmented_mix = {}
+
+         samples = mix.shape[-1]
+         margin = margin_set
+         chunk_size = chunk_set*44100
+         assert not margin == 0, 'margin cannot be zero!'
+
+         if margin > chunk_size:
+             margin = chunk_size
+         if chunk_set == 0 or samples < chunk_size:
+             chunk_size = samples
+
+         counter = -1
+         for skip in range(0, samples, chunk_size):
+             counter += 1
+             s_margin = 0 if counter == 0 else margin
+             end = min(skip+chunk_size+margin, samples)
+             start = skip-s_margin
+             segmented_mix[skip] = mix[:,start:end].copy()
+             if end == samples:
+                 break
+
+         return segmented_mix
+
+     if is_missing_mix:
+         return mix
+     else:
+         segmented_mix = get_segmented_mix()
+         raw_mix = get_segmented_mix(chunk_set=0) if mdx_net_cut else mix
+         return segmented_mix, raw_mix, samplerate
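A usage sketch of the chunking in prepare_mix above: chunks are keyed by their start sample and extended by margin samples on each side, which demix_base and demix_demucs later trim back off. Standalone, with a hypothetical 60-second stereo mix:

import numpy as np

def segment(mix: np.ndarray, chunk_seconds: int, margin: int, sr: int = 44100) -> dict:
    samples = mix.shape[-1]
    chunk_size = min(chunk_seconds * sr or samples, samples)
    margin = min(margin, chunk_size)
    out, counter = {}, -1
    for skip in range(0, samples, chunk_size):
        counter += 1
        start = skip - (0 if counter == 0 else margin)
        end = min(skip + chunk_size + margin, samples)
        out[skip] = mix[:, start:end].copy()
        if end == samples:
            break
    return out

mix = np.random.randn(2, 60 * 44100).astype(np.float32)
print([(k, v.shape[-1]) for k, v in segment(mix, chunk_seconds=15, margin=44100).items()])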
  def rerun_mp3(audio_file, sample_rate=44100):

      if save_format == MP3:
          audio_path_mp3 = audio_path.replace(".wav", ".mp3")
+         musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set)

          try:
              os.remove(audio_path)
          except Exception as e:
              print(e)