Update inference/style_transfer.py
Browse files- inference/style_transfer.py +10 -17
inference/style_transfer.py
CHANGED
@@ -26,7 +26,7 @@ from data_loader import *
|
|
26 |
|
27 |
class Mixing_Style_Transfer_Inference:
|
28 |
def __init__(self, args, trained_w_ddp=True):
|
29 |
-
if
|
30 |
self.device = torch.device("cuda:0")
|
31 |
else:
|
32 |
self.device = torch.device("cpu")
|
@@ -86,7 +86,7 @@ class Mixing_Style_Transfer_Inference:
|
|
86 |
if os.path.exists(os.path.join(cur_sep_output_dir, self.args.separation_model, cur_file_name, 'drums.wav')):
|
87 |
print(f'\talready separated current file : {cur_sep_file_path}')
|
88 |
else:
|
89 |
-
cur_cmd_line = f"demucs {cur_sep_file_path} -n {self.args.separation_model} -d {self.
|
90 |
os.system(cur_cmd_line)
|
91 |
|
92 |
|
@@ -109,7 +109,7 @@ class Mixing_Style_Transfer_Inference:
|
|
109 |
|
110 |
|
111 |
# Inference whole song
|
112 |
-
def inference(self, ):
|
113 |
print("\n======= Start to inference music mixing style transfer =======")
|
114 |
# normalized input
|
115 |
output_name_tag = 'output' if self.args.normalize_input else 'output_notnormed'
|
@@ -267,7 +267,10 @@ class Mixing_Style_Transfer_Inference:
|
|
267 |
sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
|
268 |
# remix
|
269 |
fin_data_out_mix = sum(inst_outputs)
|
270 |
-
|
|
|
|
|
|
|
271 |
|
272 |
|
273 |
# function that segmentize an entire song into batch
|
@@ -322,7 +325,7 @@ class Mixing_Style_Transfer_Inference:
|
|
322 |
|
323 |
|
324 |
|
325 |
-
|
326 |
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
327 |
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
328 |
os.environ['MASTER_PORT'] = '8888'
|
@@ -366,7 +369,7 @@ if __name__ == '__main__':
|
|
366 |
inference_args.add_argument('--stem_level_directory_name', type=str, default='separated')
|
367 |
inference_args.add_argument('--save_each_inst', type=str2bool, default=False)
|
368 |
inference_args.add_argument('--do_not_separate', type=str2bool, default=False)
|
369 |
-
inference_args.add_argument('--separation_model', type=str, default='
|
370 |
# FX normalization
|
371 |
inference_args.add_argument('--normalize_input', type=str2bool, default=True)
|
372 |
inference_args.add_argument('--normalization_order', type=str2bool, default=['loudness', 'eq', 'compression', 'imager', 'loudness']) # Effects to be normalized, order matters
|
@@ -376,9 +379,7 @@ if __name__ == '__main__':
|
|
376 |
|
377 |
device_args = parser.add_argument_group('Device args')
|
378 |
device_args.add_argument('--workers', type=int, default=1)
|
379 |
-
device_args.add_argument('--inference_device', type=str, default='gpu', help="if this option is not set to 'cpu', inference will happen on gpu only if there is a detected one")
|
380 |
device_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
|
381 |
-
device_args.add_argument('--separation_device', type=str, default='cpu', help="device for performing source separation using Demucs")
|
382 |
|
383 |
args = parser.parse_args()
|
384 |
|
@@ -388,13 +389,5 @@ if __name__ == '__main__':
|
|
388 |
args.cfg_encoder = configs['Effects_Encoder']['default']
|
389 |
args.cfg_converter = configs['TCN']['default']
|
390 |
|
|
|
391 |
|
392 |
-
# Perform music mixing style transfer
|
393 |
-
inference_style_transfer = Mixing_Style_Transfer_Inference(args)
|
394 |
-
if args.interpolation:
|
395 |
-
inference_style_transfer.inference_interpolation()
|
396 |
-
else:
|
397 |
-
inference_style_transfer.inference()
|
398 |
-
|
399 |
-
|
400 |
-
|
|
|
26 |
|
27 |
class Mixing_Style_Transfer_Inference:
|
28 |
def __init__(self, args, trained_w_ddp=True):
|
29 |
+
if torch.cuda.is_available():
|
30 |
self.device = torch.device("cuda:0")
|
31 |
else:
|
32 |
self.device = torch.device("cpu")
|
|
|
86 |
if os.path.exists(os.path.join(cur_sep_output_dir, self.args.separation_model, cur_file_name, 'drums.wav')):
|
87 |
print(f'\talready separated current file : {cur_sep_file_path}')
|
88 |
else:
|
89 |
+
cur_cmd_line = f"demucs {cur_sep_file_path} -n {self.args.separation_model} -d {self.device} -o {cur_sep_output_dir}"
|
90 |
os.system(cur_cmd_line)
|
91 |
|
92 |
|
|
|
109 |
|
110 |
|
111 |
# Inference whole song
|
112 |
+
def inference(self, input_track_path, reference_track_path):
|
113 |
print("\n======= Start to inference music mixing style transfer =======")
|
114 |
# normalized input
|
115 |
output_name_tag = 'output' if self.args.normalize_input else 'output_notnormed'
|
|
|
267 |
sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
|
268 |
# remix
|
269 |
fin_data_out_mix = sum(inst_outputs)
|
270 |
+
fin_output_path = os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav"
|
271 |
+
sf.write(fin_output_path), fin_data_out_mix.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
|
272 |
+
|
273 |
+
return fin_output_path
|
274 |
|
275 |
|
276 |
# function that segmentize an entire song into batch
|
|
|
325 |
|
326 |
|
327 |
|
328 |
+
def set_up()
|
329 |
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
330 |
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
331 |
os.environ['MASTER_PORT'] = '8888'
|
|
|
369 |
inference_args.add_argument('--stem_level_directory_name', type=str, default='separated')
|
370 |
inference_args.add_argument('--save_each_inst', type=str2bool, default=False)
|
371 |
inference_args.add_argument('--do_not_separate', type=str2bool, default=False)
|
372 |
+
inference_args.add_argument('--separation_model', type=str, default='htdemucs')
|
373 |
# FX normalization
|
374 |
inference_args.add_argument('--normalize_input', type=str2bool, default=True)
|
375 |
inference_args.add_argument('--normalization_order', type=str2bool, default=['loudness', 'eq', 'compression', 'imager', 'loudness']) # Effects to be normalized, order matters
|
|
|
379 |
|
380 |
device_args = parser.add_argument_group('Device args')
|
381 |
device_args.add_argument('--workers', type=int, default=1)
|
|
|
382 |
device_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
|
|
|
383 |
|
384 |
args = parser.parse_args()
|
385 |
|
|
|
389 |
args.cfg_encoder = configs['Effects_Encoder']['default']
|
390 |
args.cfg_converter = configs['TCN']['default']
|
391 |
|
392 |
+
return args
|
393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|