jhtonyKoo commited on
Commit
8c9ee04
·
1 Parent(s): b780126

update loss

Browse files
app.py CHANGED
@@ -94,7 +94,8 @@ def process_audio(input_audio, reference_audio):
94
 
95
  return (sr, output_audio), param_output, (sr, normalized_input)
96
 
97
- def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights):
 
98
  if ito_reference_audio is None:
99
  ito_reference_audio = reference_audio
100
  af_weights = [float(w.strip()) for w in af_weights.split(',')]
@@ -104,7 +105,10 @@ def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, op
104
  'learning_rate': learning_rate,
105
  'num_steps': num_steps,
106
  'af_weights': af_weights,
107
- 'sample_rate': args.sample_rate
 
 
 
108
  }
109
 
110
  input_tensor = mastering_transfer.preprocess_audio(input_audio, args.sample_rate)
@@ -219,7 +223,22 @@ with gr.Blocks() as demo:
219
  optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
220
  learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
221
  af_weights = gr.Textbox(label="AudioFeatureLoss Weights (comma-separated)", value="0.1,0.001,1.0,1.0,0.1")
222
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  ito_button = gr.Button("Perform ITO")
224
 
225
  with gr.Row():
@@ -243,7 +262,7 @@ with gr.Blocks() as demo:
243
 
244
  ito_button.click(
245
  perform_ito,
246
- inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights],
247
  outputs=[ito_output_audio, ito_param_output, ito_step_slider, ito_log, ito_loss_plot, all_results]
248
  ).then(
249
  update_ito_output,
 
94
 
95
  return (sr, output_audio), param_output, (sr, normalized_input)
96
 
97
+ # def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights):
98
+ def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt):
99
  if ito_reference_audio is None:
100
  ito_reference_audio = reference_audio
101
  af_weights = [float(w.strip()) for w in af_weights.split(',')]
 
105
  'learning_rate': learning_rate,
106
  'num_steps': num_steps,
107
  'af_weights': af_weights,
108
+ 'sample_rate': args.sample_rate,
109
+ 'loss_function': loss_function,
110
+ 'clap_target_type': clap_target_type,
111
+ 'clap_text_prompt': clap_text_prompt
112
  }
113
 
114
  input_tensor = mastering_transfer.preprocess_audio(input_audio, args.sample_rate)
 
223
  optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
224
  learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
225
  af_weights = gr.Textbox(label="AudioFeatureLoss Weights (comma-separated)", value="0.1,0.001,1.0,1.0,0.1")
226
+ loss_function = gr.Radio(["AudioFeatureLoss", "CLAPFeatureLoss"], label="Loss Function", value="AudioFeatureLoss")
227
+ clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio", visible=False)
228
+ clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False)
229
+
230
+ def update_clap_options(loss_function):
231
+ if loss_function == "CLAPFeatureLoss":
232
+ return gr.update(visible=True), gr.update(visible=True)
233
+ else:
234
+ return gr.update(visible=False), gr.update(visible=False)
235
+
236
+ loss_function.change(
237
+ update_clap_options,
238
+ inputs=[loss_function],
239
+ outputs=[clap_target_type, clap_text_prompt]
240
+ )
241
+
242
  ito_button = gr.Button("Perform ITO")
243
 
244
  with gr.Row():
 
262
 
263
  ito_button.click(
264
  perform_ito,
265
+ inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt],
266
  outputs=[ito_output_audio, ito_param_output, ito_step_slider, ito_log, ito_loss_plot, all_results]
267
  ).then(
268
  update_ito_output,
inference.py CHANGED
@@ -34,7 +34,14 @@ class MasteringStyleTransfer:
34
  self.fx_normalizer = Audio_Effects_Normalizer(precomputed_feature_path=args.fx_norm_feature_path, \
35
  STEMS=['mixture'], \
36
  EFFECTS=['eq', 'imager', 'loudness'])
37
-
 
 
 
 
 
 
 
38
 
39
  def load_effects_encoder(self):
40
  effects_encoder = Effects_Encoder(self.args.cfg_enc)
@@ -70,13 +77,6 @@ class MasteringStyleTransfer:
70
  fit_embedding = torch.nn.Parameter(initial_reference_feature)
71
  optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])
72
 
73
- af_loss = AudioFeatureLoss(
74
- weights=ito_config['af_weights'],
75
- sample_rate=ito_config['sample_rate'],
76
- stem_separation=False,
77
- use_clap=False
78
- )
79
-
80
  min_loss = float('inf')
81
  min_loss_step = 0
82
  all_results = []
@@ -87,7 +87,15 @@ class MasteringStyleTransfer:
87
  output_audio = self.mastering_converter(input_tensor, fit_embedding)
88
  current_params = self.mastering_converter.get_last_predicted_params()
89
 
90
- losses = af_loss(output_audio, reference_tensor)
 
 
 
 
 
 
 
 
91
  total_loss = sum(losses.values())
92
 
93
  if total_loss < min_loss:
 
34
  self.fx_normalizer = Audio_Effects_Normalizer(precomputed_feature_path=args.fx_norm_feature_path, \
35
  STEMS=['mixture'], \
36
  EFFECTS=['eq', 'imager', 'loudness'])
37
+ # Loss functions
38
+ self.af_loss = AudioFeatureLoss(
39
+ weights=ito_config['af_weights'],
40
+ sample_rate=ito_config['sample_rate'],
41
+ stem_separation=False,
42
+ use_clap=False
43
+ )
44
+ self.clap_loss = CLAPFeatureLoss(distance_fn='cosine')
45
 
46
  def load_effects_encoder(self):
47
  effects_encoder = Effects_Encoder(self.args.cfg_enc)
 
77
  fit_embedding = torch.nn.Parameter(initial_reference_feature)
78
  optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])
79
 
 
 
 
 
 
 
 
80
  min_loss = float('inf')
81
  min_loss_step = 0
82
  all_results = []
 
87
  output_audio = self.mastering_converter(input_tensor, fit_embedding)
88
  current_params = self.mastering_converter.get_last_predicted_params()
89
 
90
+ # Compute loss
91
+ if ito_config['loss_function'] == 'AudioFeatureLoss':
92
+ losses = self.af_loss(output_audio, reference_tensor)
93
+ elif ito_config['loss_function'] == 'CLAPFeatureLoss':
94
+ if ito_config['clap_target_type'] == 'Audio':
95
+ target = ito_reference_tensor
96
+ else:
97
+ target = ito_config['clap_text_prompt']
98
+ losses = self.clap_loss(est_targets, target, self.args.sample_rate)
99
  total_loss = sum(losses.values())
100
 
101
  if total_loss < min_loss:
modules/__pycache__/loss.cpython-311.pyc CHANGED
Binary files a/modules/__pycache__/loss.cpython-311.pyc and b/modules/__pycache__/loss.cpython-311.pyc differ
 
modules/loss.py CHANGED
@@ -196,36 +196,50 @@ class CLAPFeatureLoss(nn.Module):
196
  else:
197
  raise ValueError(f"Unsupported distance function: {distance_fn}")
198
 
199
- def forward(self, input_audio, target_audio, sample_rate):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  # Ensure input is in the correct shape (N, C, T)
201
- if input_audio.dim() == 2:
202
- input_audio = input_audio.unsqueeze(1)
203
- if target_audio.dim() == 2:
204
- target_audio = target_audio.unsqueeze(1)
205
 
206
  # Convert to mono if stereo
207
- if input_audio.shape[1] > 1:
208
- input_audio = input_audio.mean(dim=1, keepdim=True)
209
- if target_audio.shape[1] > 1:
210
- target_audio = target_audio.mean(dim=1, keepdim=True)
211
 
212
  # Resample if necessary
213
  if sample_rate != self.target_sample_rate:
214
- input_audio = self.resample(input_audio, sample_rate)
215
- target_audio = self.resample(target_audio, sample_rate)
216
 
217
  # Quantize audio data
218
- input_audio = self.quantize(input_audio)
219
- target_audio = self.quantize(target_audio)
220
 
221
  # Get CLAP embeddings
222
- input_embed = self.model.get_audio_embedding_from_data(x=input_audio, use_tensor=True)
223
- target_embed = self.model.get_audio_embedding_from_data(x=target_audio, use_tensor=True)
224
 
225
- # Compute loss using the specified distance function
226
- loss = self.compute_distance(input_embed, target_embed)
227
-
228
- return loss
 
 
 
229
 
230
  def quantize(self, audio):
231
  audio = audio.squeeze(1) # Remove channel dimension
@@ -490,4 +504,19 @@ class AudioFeatureLoss(torch.nn.Module):
490
  val = torch.nn.functional.mse_loss(input_transform, target_transform)
491
  losses[key] = weight * val * self.source_weights[stem_idx]
492
 
493
- return losses
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  else:
197
  raise ValueError(f"Unsupported distance function: {distance_fn}")
198
 
199
+ def forward(self, input_audio, target, sample_rate):
200
+ # Process input audio
201
+ input_embed = self.process_audio(input_audio, sample_rate)
202
+
203
+ # Process target (audio or text)
204
+ if isinstance(target, torch.Tensor):
205
+ target_embed = self.process_audio(target, sample_rate)
206
+ elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
207
+ target_embed = self.process_text(target)
208
+ else:
209
+ raise ValueError("Target must be either audio tensor or text (string or list of strings)")
210
+
211
+ # Compute loss using the specified distance function
212
+ loss = self.compute_distance(input_embed, target_embed)
213
+
214
+ return loss
215
+
216
+ def process_audio(self, audio, sample_rate):
217
  # Ensure input is in the correct shape (N, C, T)
218
+ if audio.dim() == 2:
219
+ audio = audio.unsqueeze(1)
 
 
220
 
221
  # Convert to mono if stereo
222
+ if audio.shape[1] > 1:
223
+ audio = audio.mean(dim=1, keepdim=True)
 
 
224
 
225
  # Resample if necessary
226
  if sample_rate != self.target_sample_rate:
227
+ audio = self.resample(audio, sample_rate)
 
228
 
229
  # Quantize audio data
230
+ audio = self.quantize(audio)
 
231
 
232
  # Get CLAP embeddings
233
+ embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
234
+ return embed
235
 
236
+ def process_text(self, text):
237
+ # Get CLAP embeddings for text
238
+ # ensure input is a list of strings
239
+ if not isinstance(text, list):
240
+ text = [text]
241
+ embed = self.model.get_text_embedding(text, use_tensor=True)
242
+ return embed
243
 
244
  def quantize(self, audio):
245
  audio = audio.squeeze(1) # Remove channel dimension
 
504
  val = torch.nn.functional.mse_loss(input_transform, target_transform)
505
  losses[key] = weight * val * self.source_weights[stem_idx]
506
 
507
+ return losses
508
+
509
+
510
+
511
+ if __name__ == "__main__":
512
+ clap_loss = CLAPFeatureLoss(distance_fn='cosine')
513
+
514
+ input_audio = torch.randn(1, 2, 44100)
515
+ target_audio = torch.randn(1, 2, 44100)
516
+ target_text = "This is a test"
517
+ sample_rate = 44100
518
+ loss = clap_loss(input_audio, target_audio, sample_rate)
519
+ print(loss)
520
+ loss = clap_loss(input_audio, target_text, sample_rate)
521
+ print(loss)
522
+