jhtonyKoo committed on
Commit
78bac9e
·
1 Parent(s): 278b4aa

update loss

Browse files
Files changed (2) hide show
  1. inference.py +1 -1
  2. modules/loss.py +91 -24
inference.py CHANGED
@@ -99,6 +99,7 @@ class MasteringStyleTransfer:
99
  target = ito_config['clap_text_prompt']
100
  print(f'ito_config clap_distance_fn: {ito_config["clap_distance_fn"]}')
101
  total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
 
102
 
103
  if total_loss < min_loss:
104
  min_loss = total_loss.item()
@@ -243,7 +244,6 @@ class MasteringStyleTransfer:
243
  if isinstance(param_value, torch.Tensor):
244
  param_value = param_value.item()
245
 
246
- print(f"fx name: {fx_name} param_name: {param_name}")
247
  if fx_name in param_mapper and param_name in param_mapper[fx_name]:
248
  friendly_name, unit, min_val, max_val = param_mapper[fx_name][param_name]
249
  if unit=='%':
 
99
  target = ito_config['clap_text_prompt']
100
  print(f'ito_config clap_distance_fn: {ito_config["clap_distance_fn"]}')
101
  total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
102
+ print(f'total_loss: {total_loss}')
103
 
104
  if total_loss < min_loss:
105
  min_loss = total_loss.item()
 
244
  if isinstance(param_value, torch.Tensor):
245
  param_value = param_value.item()
246
 
 
247
  if fx_name in param_mapper and param_name in param_mapper[fx_name]:
248
  friendly_name, unit, min_val, max_val = param_mapper[fx_name][param_name]
249
  if unit=='%':
modules/loss.py CHANGED
@@ -185,25 +185,35 @@ class CLAPFeatureLoss(nn.Module):
185
  self.target_sample_rate = 48000 # CLAP expects 48kHz audio
186
  self.model = laion_clap.CLAP_Module(enable_fusion=False)
187
  self.model.load_ckpt() # download the default pretrained checkpoint
 
 
 
 
188
 
189
- def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
190
  # Process input audio
191
- input_embed = self.process_audio(input_audio, sample_rate)
 
 
 
 
192
 
193
  # Process target (audio or text)
194
- if isinstance(target, torch.Tensor):
195
- target_embed = self.process_audio(target, sample_rate)
196
- elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
197
- target_embed = self.process_text(target)
198
- else:
199
- raise ValueError("Target must be either audio tensor or text (string or list of strings)")
 
 
200
 
201
  # Compute loss using the specified distance function
202
  loss = self.compute_distance(input_embed, target_embed, distance_fn)
203
 
204
  return loss
205
 
206
- def process_audio(self, audio, sample_rate):
207
  # Ensure input is in the correct shape (N, C, T)
208
  if audio.dim() == 2:
209
  audio = audio.unsqueeze(1)
@@ -219,19 +229,7 @@ class CLAPFeatureLoss(nn.Module):
219
  # Quantize audio data
220
  audio = self.quantize(audio)
221
 
222
- # Get CLAP embeddings
223
- with torch.no_grad():
224
- embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
225
- return embed
226
-
227
- def process_text(self, text):
228
- # Get CLAP embeddings for text
229
- # ensure input is a list of strings
230
- if not isinstance(text, list):
231
- text = [text]
232
- with torch.no_grad():
233
- embed = self.model.get_text_embedding(text, use_tensor=True)
234
- return embed
235
 
236
  def compute_distance(self, x, y, distance_fn):
237
  if distance_fn == 'mse':
@@ -249,11 +247,80 @@ class CLAPFeatureLoss(nn.Module):
249
  audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
250
  return audio
251
 
252
- def resample(self, audio, input_sample_rate):
253
  resampler = torchaudio.transforms.Resample(
254
- orig_freq=input_sample_rate, new_freq=self.target_sample_rate
255
  ).to(audio.device)
256
  return resampler(audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
 
259
  """
 
185
  self.target_sample_rate = 48000 # CLAP expects 48kHz audio
186
  self.model = laion_clap.CLAP_Module(enable_fusion=False)
187
  self.model.load_ckpt() # download the default pretrained checkpoint
188
+
189
+ # Freeze the CLAP model parameters
190
+ for param in self.model.parameters():
191
+ param.requires_grad = False
192
 
193
+ def forward(self, input_audio, target, sample_rate, distance_fn='mse'):
194
  # Process input audio
195
+ with torch.no_grad():
196
+ input_audio = self.preprocess_audio(input_audio, sample_rate)
197
+
198
+ with torch.enable_grad():
199
+ input_embed = self.model.get_audio_embedding_from_data(x=input_audio, use_tensor=True)
200
 
201
  # Process target (audio or text)
202
+ with torch.no_grad():
203
+ if isinstance(target, torch.Tensor):
204
+ target_audio = self.preprocess_audio(target, sample_rate)
205
+ target_embed = self.model.get_audio_embedding_from_data(x=target_audio, use_tensor=True)
206
+ elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
207
+ target_embed = self.model.get_text_embedding(target, use_tensor=True)
208
+ else:
209
+ raise ValueError("Target must be either audio tensor or text (string or list of strings)")
210
 
211
  # Compute loss using the specified distance function
212
  loss = self.compute_distance(input_embed, target_embed, distance_fn)
213
 
214
  return loss
215
 
216
+ def preprocess_audio(self, audio, sample_rate):
217
  # Ensure input is in the correct shape (N, C, T)
218
  if audio.dim() == 2:
219
  audio = audio.unsqueeze(1)
 
229
  # Quantize audio data
230
  audio = self.quantize(audio)
231
 
232
+ return audio
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  def compute_distance(self, x, y, distance_fn):
235
  if distance_fn == 'mse':
 
247
  audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
248
  return audio
249
 
250
+ def resample(self, audio, orig_sample_rate):
251
  resampler = torchaudio.transforms.Resample(
252
+ orig_freq=orig_sample_rate, new_freq=self.target_sample_rate
253
  ).to(audio.device)
254
  return resampler(audio)
255
+
256
+ # def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
257
+ # # Process input audio
258
+ # input_embed = self.process_audio(input_audio, sample_rate)
259
+
260
+ # # Process target (audio or text)
261
+ # if isinstance(target, torch.Tensor):
262
+ # target_embed = self.process_audio(target, sample_rate)
263
+ # elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
264
+ # target_embed = self.process_text(target)
265
+ # else:
266
+ # raise ValueError("Target must be either audio tensor or text (string or list of strings)")
267
+
268
+ # # Compute loss using the specified distance function
269
+ # loss = self.compute_distance(input_embed, target_embed, distance_fn)
270
+
271
+ # return loss
272
+
273
+ # def process_audio(self, audio, sample_rate):
274
+ # # Ensure input is in the correct shape (N, C, T)
275
+ # if audio.dim() == 2:
276
+ # audio = audio.unsqueeze(1)
277
+
278
+ # # Convert to mono if stereo
279
+ # if audio.shape[1] > 1:
280
+ # audio = audio.mean(dim=1, keepdim=True)
281
+
282
+ # # Resample if necessary
283
+ # if sample_rate != self.target_sample_rate:
284
+ # audio = self.resample(audio, sample_rate)
285
+
286
+ # # Quantize audio data
287
+ # audio = self.quantize(audio)
288
+
289
+ # # Get CLAP embeddings
290
+ # with torch.no_grad():
291
+ # embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
292
+ # return embed
293
+
294
+ # def process_text(self, text):
295
+ # # Get CLAP embeddings for text
296
+ # # ensure input is a list of strings
297
+ # if not isinstance(text, list):
298
+ # text = [text]
299
+ # with torch.no_grad():
300
+ # embed = self.model.get_text_embedding(text, use_tensor=True)
301
+ # return embed
302
+
303
+ # def compute_distance(self, x, y, distance_fn):
304
+ # if distance_fn == 'mse':
305
+ # return F.mse_loss(x, y)
306
+ # elif distance_fn == 'l1':
307
+ # return F.l1_loss(x, y)
308
+ # elif distance_fn == 'cosine':
309
+ # return 1 - F.cosine_similarity(x, y).mean()
310
+ # else:
311
+ # raise ValueError(f"Unsupported distance function: {distance_fn}")
312
+
313
+ # def quantize(self, audio):
314
+ # audio = audio.squeeze(1) # Remove channel dimension
315
+ # audio = torch.clamp(audio, -1.0, 1.0)
316
+ # audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
317
+ # return audio
318
+
319
+ # def resample(self, audio, input_sample_rate):
320
+ # resampler = torchaudio.transforms.Resample(
321
+ # orig_freq=input_sample_rate, new_freq=self.target_sample_rate
322
+ # ).to(audio.device)
323
+ # return resampler(audio)
324
 
325
 
326
  """