jhtonyKoo commited on
Commit
07846b5
1 Parent(s): 077a11b

Update inference/mastering_transfer.py

Files changed (1)
  1. inference/mastering_transfer.py +51 -62
inference/mastering_transfer.py CHANGED
@@ -1,7 +1,6 @@
 """
 Inference code of music style transfer
 of the work "Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects"
-
 Process: converts the mastering style of the input music recording to that of the reference music.
 Files inside the target directory should be organized as follows:
 "path_to_data_directory"/"song_name_#1"/input.wav
@@ -112,73 +111,64 @@ class Mastering_Style_Transfer_Inference:
         # normalized input
         output_name_tag = 'output' if self.args.normalize_input else 'output_notnormed'
 
-        input_aud = load_wav_segment(input_track_path)
-        reference_aud = load_wav_segment(reference_track_path)
+        input_aud = load_wav_segment(input_track_path, axis=0)
+        reference_aud = load_wav_segment(reference_track_path, axis=0)
 
         # input_stems, reference_stems, dir_name
         print(f"---inference file name : {dir_name[0]}---")
         cur_out_dir = dir_name[0].replace(self.target_dir, self.output_dir)
         os.makedirs(cur_out_dir, exist_ok=True)
-        ''' stem-level inference '''
-        inst_outputs = []
-        for cur_inst_idx, cur_inst_name in enumerate(self.args.instruments):
-            print(f'\t{cur_inst_name}...')
-            ''' segmentize whole songs into batch '''
-            if len(input_stems[0][cur_inst_idx][0]) > self.args.segment_length:
-                cur_inst_input_stem = self.batchwise_segmentization(input_stems[0][cur_inst_idx], \
-                                                                    dir_name[0], \
-                                                                    segment_length=self.args.segment_length, \
-                                                                    discard_last=False)
-            else:
-                cur_inst_input_stem = [input_stems[:, cur_inst_idx]]
-            if len(reference_stems[0][cur_inst_idx][0]) > self.args.segment_length*2:
-                cur_inst_reference_stem = self.batchwise_segmentization(reference_stems[0][cur_inst_idx], \
-                                                                        dir_name[0], \
-                                                                        segment_length=self.args.segment_length_ref, \
-                                                                        discard_last=False)
-            else:
-                cur_inst_reference_stem = [reference_stems[:, cur_inst_idx]]
-
-            ''' inference '''
-            # first extract the reference style embedding
-            infered_ref_data_list = []
-            for cur_ref_data in cur_inst_reference_stem:
-                cur_ref_data = cur_ref_data.to(self.device)
-                # Effects Encoder inference
-                with torch.no_grad():
-                    self.models["effects_encoder"].eval()
-                    reference_feature = self.models["effects_encoder"](cur_ref_data)
-                infered_ref_data_list.append(reference_feature)
-            # compute the average of the extracted embeddings
-            infered_ref_data = torch.stack(infered_ref_data_list)
-            infered_ref_data_avg = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
-
-            # mixing style converter
-            infered_data_list = []
-            for cur_data in cur_inst_input_stem:
-                cur_data = cur_data.to(self.device)
-                with torch.no_grad():
-                    self.models["mastering_converter"].eval()
-                    infered_data = self.models["mastering_converter"](cur_data, infered_ref_data_avg.unsqueeze(0))
-                infered_data_list.append(infered_data.cpu().detach())
-
-            # combine back to whole song
-            for cur_idx, cur_batch_infered_data in enumerate(infered_data_list):
-                cur_infered_data_sequential = torch.cat(torch.unbind(cur_batch_infered_data, dim=0), dim=-1)
-                fin_data_out = cur_infered_data_sequential if cur_idx==0 else torch.cat((fin_data_out, cur_infered_data_sequential), dim=-1)
-            # final output of current instrument
-            fin_data_out_inst = fin_data_out[:, :input_stems[0][cur_inst_idx].shape[-1]].numpy()
-
-            inst_outputs.append(fin_data_out_inst)
-            # save output of each instrument
-            if self.args.save_each_inst:
-                sf.write(os.path.join(cur_out_dir, f"{cur_inst_name}_{output_name_tag}.wav"), fin_data_out_inst.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
+        ''' segmentize whole songs into batch '''
+        if input_aud.shape[1] > self.args.segment_length:
+            cur_inst_input_stem = self.batchwise_segmentization(input_aud, \
+                                                                dir_name[0], \
+                                                                segment_length=self.args.segment_length, \
+                                                                discard_last=False)
+        else:
+            cur_inst_input_stem = [input_aud.unsqueeze(0)]
+        if reference_aud.shape[1] > self.args.segment_length*2:
+            cur_inst_reference_stem = self.batchwise_segmentization(reference_aud, \
+                                                                    dir_name[0], \
+                                                                    segment_length=self.args.segment_length_ref, \
+                                                                    discard_last=False)
+        else:
+            cur_inst_reference_stem = [reference_aud.unsqueeze(0)]
+
+        ''' inference '''
+        # first extract the reference style embedding
+        infered_ref_data_list = []
+        for cur_ref_data in cur_inst_reference_stem:
+            cur_ref_data = cur_ref_data.to(self.device)
+            # Effects Encoder inference
+            with torch.no_grad():
+                self.models["effects_encoder"].eval()
+                reference_feature = self.models["effects_encoder"](cur_ref_data)
+            infered_ref_data_list.append(reference_feature)
+        # compute the average of the extracted embeddings
+        infered_ref_data = torch.stack(infered_ref_data_list)
+        infered_ref_data_avg = torch.mean(infered_ref_data.reshape(infered_ref_data.shape[0]*infered_ref_data.shape[1], infered_ref_data.shape[2]), axis=0)
+
+        # mastering style converter
+        infered_data_list = []
+        for cur_data in cur_inst_input_stem:
+            cur_data = cur_data.to(self.device)
+            with torch.no_grad():
+                self.models["mastering_converter"].eval()
+                infered_data = self.models["mastering_converter"](cur_data, infered_ref_data_avg.unsqueeze(0))
+            infered_data_list.append(infered_data.cpu().detach())
+
+        # combine back to whole song
+        for cur_idx, cur_batch_infered_data in enumerate(infered_data_list):
+            cur_infered_data_sequential = torch.cat(torch.unbind(cur_batch_infered_data, dim=0), dim=-1)
+            fin_data_out = cur_infered_data_sequential if cur_idx==0 else torch.cat((fin_data_out, cur_infered_data_sequential), dim=-1)
+        # final output, trimmed back to the input's original length
+        fin_data_out_mastered = fin_data_out[:, :input_aud.shape[-1]].numpy()
+
         # remix
-        fin_data_out_mix = sum(inst_outputs)
-        fin_output_path = os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav")
-        sf.write(fin_output_path, fin_data_out_mix.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
+        fin_output_path = os.path.join(cur_out_dir, "remastered_output.wav")
+        sf.write(fin_output_path, fin_data_out_mastered.transpose(-1, -2), self.args.sample_rate, 'PCM_16')
 
         return fin_output_path
 
 
     # Inference whole song
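The helper self.batchwise_segmentization is called in both branches above, but its definition is not part of this diff. Below is a minimal sketch of what such a helper presumably does, so the segment_length / discard_last arguments are easier to follow; the function name, the batch_size parameter, and the zero-padding strategy are assumptions, not the repository's actual implementation:

import torch
import torch.nn.functional as F

def batchwise_segmentization_sketch(audio, segment_length, batch_size=8, discard_last=False):
    # audio: [channels, samples] waveform tensor of a whole song (assumed layout)
    channels, total_length = audio.shape
    remainder = total_length % segment_length
    if remainder and discard_last:
        audio = audio[:, :total_length - remainder]            # drop the partial tail
    elif remainder:
        audio = F.pad(audio, (0, segment_length - remainder))  # zero-pad the tail
    # [channels, n_seg * segment_length] -> [n_seg, channels, segment_length]
    segments = audio.reshape(channels, -1, segment_length).permute(1, 0, 2)
    # group segments into batches so a whole song is processed chunk by chunk
    return [segments[i:i + batch_size] for i in range(0, segments.shape[0], batch_size)]

Zero-padding the tail is consistent with the trimming step in the diff: fin_data_out[:, :input_aud.shape[-1]] cuts the inferred output back to the input's original length.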
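The reference style embedding is averaged over every segment of the reference track, and the inferred segments are then concatenated back along time. The shapes are easy to misread, so here is a toy check of both steps (the batch size of 8, 2-channel audio, and embedding width 2048 are placeholders, not the models' actual dimensions). One caveat: torch.stack requires every element of infered_ref_data_list to have the same batch size, so the reshape-then-mean only works when the last batch is full; torch.cat(infered_ref_data_list, dim=0).mean(dim=0) would be the safer general form.

import torch

# averaging: two batches of 8 per-segment style embeddings
emb_batches = [torch.randn(8, 2048), torch.randn(8, 2048)]
stacked = torch.stack(emb_batches)                              # [2, 8, 2048]
style_avg = stacked.reshape(-1, stacked.shape[-1]).mean(dim=0)  # [2048], one vector per song

# recombination: each inferred batch is [batch, channels, segment_length];
# unbind along the batch dim and concatenate along time
batch_out = torch.randn(8, 2, 44100)
sequential = torch.cat(torch.unbind(batch_out, dim=0), dim=-1)
assert sequential.shape == (2, 8 * 44100)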
@@ -375,4 +365,3 @@ def set_up_mastering(start_point_in_second=0, duration_in_second=30):
     args.cfg_converter = configs['TCN']['default']
 
     return args
-
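Putting the pieces together, a hypothetical driver could look like the sketch below. Only set_up_mastering's signature is shown in this diff; the constructor call and the name of the inference entry point are assumptions based on the identifiers visible above:

# hypothetical usage sketch, not code from this repository
args = set_up_mastering(start_point_in_second=0, duration_in_second=30)

mastering = Mastering_Style_Transfer_Inference(args)  # assumed constructor signature
output_path = mastering.inference(                    # assumed method name
    input_track_path="target_dir/song_name_#1/input.wav",
    reference_track_path="target_dir/song_name_#1/reference.wav",
)
print(output_path)  # e.g. ".../remastered_output.wav"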
 