jhtonyKoo commited on
Commit
66e10e8
1 Parent(s): f528768

Update inference/style_transfer.py

Browse files
Files changed (1) hide show
  1. inference/style_transfer.py +10 -17
inference/style_transfer.py CHANGED
@@ -26,7 +26,7 @@ from data_loader import *
26
 
27
  class Mixing_Style_Transfer_Inference:
28
  def __init__(self, args, trained_w_ddp=True):
29
- if args.inference_device!='cpu' and torch.cuda.is_available():
30
  self.device = torch.device("cuda:0")
31
  else:
32
  self.device = torch.device("cpu")
@@ -86,7 +86,7 @@ class Mixing_Style_Transfer_Inference:
86
  if os.path.exists(os.path.join(cur_sep_output_dir, self.args.separation_model, cur_file_name, 'drums.wav')):
87
  print(f'\talready separated current file : {cur_sep_file_path}')
88
  else:
89
- cur_cmd_line = f"demucs {cur_sep_file_path} -n {self.args.separation_model} -d {self.args.separation_device} -o {cur_sep_output_dir}"
90
  os.system(cur_cmd_line)
91
 
92
 
@@ -109,7 +109,7 @@ class Mixing_Style_Transfer_Inference:
109
 
110
 
111
  # Inference whole song
112
- def inference(self, ):
113
  print("\n======= Start to inference music mixing style transfer =======")
114
  # normalized input
115
  output_name_tag = 'output' if self.args.normalize_input else 'output_notnormed'
@@ -267,7 +267,10 @@ class Mixing_Style_Transfer_Inference:
267
  sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
268
  # remix
269
  fin_data_out_mix = sum(inst_outputs)
270
- sf.write(os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav"), fin_data_out_mix.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
 
 
 
271
 
272
 
273
  # function that segmentize an entire song into batch
@@ -322,7 +325,7 @@ class Mixing_Style_Transfer_Inference:
322
 
323
 
324
 
325
- if __name__ == '__main__':
326
  os.environ['MASTER_ADDR'] = '127.0.0.1'
327
  os.environ["CUDA_VISIBLE_DEVICES"] = '0'
328
  os.environ['MASTER_PORT'] = '8888'
@@ -366,7 +369,7 @@ if __name__ == '__main__':
366
  inference_args.add_argument('--stem_level_directory_name', type=str, default='separated')
367
  inference_args.add_argument('--save_each_inst', type=str2bool, default=False)
368
  inference_args.add_argument('--do_not_separate', type=str2bool, default=False)
369
- inference_args.add_argument('--separation_model', type=str, default='mdx_extra')
370
  # FX normalization
371
  inference_args.add_argument('--normalize_input', type=str2bool, default=True)
372
  inference_args.add_argument('--normalization_order', type=str2bool, default=['loudness', 'eq', 'compression', 'imager', 'loudness']) # Effects to be normalized, order matters
@@ -376,9 +379,7 @@ if __name__ == '__main__':
376
 
377
  device_args = parser.add_argument_group('Device args')
378
  device_args.add_argument('--workers', type=int, default=1)
379
- device_args.add_argument('--inference_device', type=str, default='gpu', help="if this option is not set to 'cpu', inference will happen on gpu only if there is a detected one")
380
  device_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
381
- device_args.add_argument('--separation_device', type=str, default='cpu', help="device for performing source separation using Demucs")
382
 
383
  args = parser.parse_args()
384
 
@@ -388,13 +389,5 @@ if __name__ == '__main__':
388
  args.cfg_encoder = configs['Effects_Encoder']['default']
389
  args.cfg_converter = configs['TCN']['default']
390
 
 
391
 
392
- # Perform music mixing style transfer
393
- inference_style_transfer = Mixing_Style_Transfer_Inference(args)
394
- if args.interpolation:
395
- inference_style_transfer.inference_interpolation()
396
- else:
397
- inference_style_transfer.inference()
398
-
399
-
400
-
 
26
 
27
  class Mixing_Style_Transfer_Inference:
28
  def __init__(self, args, trained_w_ddp=True):
29
+ if torch.cuda.is_available():
30
  self.device = torch.device("cuda:0")
31
  else:
32
  self.device = torch.device("cpu")
 
86
  if os.path.exists(os.path.join(cur_sep_output_dir, self.args.separation_model, cur_file_name, 'drums.wav')):
87
  print(f'\talready separated current file : {cur_sep_file_path}')
88
  else:
89
+ cur_cmd_line = f"demucs {cur_sep_file_path} -n {self.args.separation_model} -d {self.device} -o {cur_sep_output_dir}"
90
  os.system(cur_cmd_line)
91
 
92
 
 
109
 
110
 
111
  # Inference whole song
112
+ def inference(self, input_track_path, reference_track_path):
113
  print("\n======= Start to inference music mixing style transfer =======")
114
  # normalized input
115
  output_name_tag = 'output' if self.args.normalize_input else 'output_notnormed'
 
267
  sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
268
  # remix
269
  fin_data_out_mix = sum(inst_outputs)
270
+ fin_output_path = os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav")
271
+ sf.write(fin_output_path, fin_data_out_mix.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
272
+
273
+ return fin_output_path
274
 
275
 
276
  # function that segmentize an entire song into batch
 
325
 
326
 
327
 
328
+ def set_up():
329
  os.environ['MASTER_ADDR'] = '127.0.0.1'
330
  os.environ["CUDA_VISIBLE_DEVICES"] = '0'
331
  os.environ['MASTER_PORT'] = '8888'
 
369
  inference_args.add_argument('--stem_level_directory_name', type=str, default='separated')
370
  inference_args.add_argument('--save_each_inst', type=str2bool, default=False)
371
  inference_args.add_argument('--do_not_separate', type=str2bool, default=False)
372
+ inference_args.add_argument('--separation_model', type=str, default='htdemucs')
373
  # FX normalization
374
  inference_args.add_argument('--normalize_input', type=str2bool, default=True)
375
  inference_args.add_argument('--normalization_order', type=str2bool, default=['loudness', 'eq', 'compression', 'imager', 'loudness']) # Effects to be normalized, order matters
 
379
 
380
  device_args = parser.add_argument_group('Device args')
381
  device_args.add_argument('--workers', type=int, default=1)
 
382
  device_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
 
383
 
384
  args = parser.parse_args()
385
 
 
389
  args.cfg_encoder = configs['Effects_Encoder']['default']
390
  args.cfg_converter = configs['TCN']['default']
391
 
392
+ return args
393