helloWorld199 commited on
Commit
111965c
1 Parent(s): c5ea3da

Update src/mdx.py

Browse files
Files changed (1) hide show
  1. src/mdx.py +3 -3
src/mdx.py CHANGED
@@ -239,7 +239,7 @@ class MDX:
239
  return self.segment(processed_batches, True, chunk)
240
 
241
 
242
- def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False, suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2):
243
  device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
244
 
245
  #device_properties = torch.cuda.get_device_properties(device)
@@ -274,14 +274,14 @@ def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False,
274
 
275
  main_filepath = None
276
  if not exclude_main:
277
- main_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
278
  sf.write(main_filepath, wave_processed.T, sr)
279
 
280
  invert_filepath = None
281
  if not exclude_inversion:
282
  diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
283
  stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
284
- invert_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
285
  sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)
286
 
287
  if not keep_orig:
 
239
  return self.segment(processed_batches, True, chunk)
240
 
241
 
242
+ def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False, suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2, _stemname1="", _stemname2=""):
243
  device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
244
 
245
  #device_properties = torch.cuda.get_device_properties(device)
 
274
 
275
  main_filepath = None
276
  if not exclude_main:
277
+ main_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}{_stemname1}.wav")
278
  sf.write(main_filepath, wave_processed.T, sr)
279
 
280
  invert_filepath = None
281
  if not exclude_inversion:
282
  diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
283
  stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
284
+ invert_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}{_stemname2}.wav")
285
  sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)
286
 
287
  if not keep_orig: