ALeLacheur commited on
Commit
3ad523c
·
verified ·
1 Parent(s): 22f7428

Update audio_diffusion_attacks_forhf/src/test_encoder_attack.py

Browse files
audio_diffusion_attacks_forhf/src/test_encoder_attack.py CHANGED
@@ -305,13 +305,14 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
305
  print('step', step, 'loss', loss.item(), 'latent loss', latent_diff.item(), 'audio loss', waveform_diff.item(), 'c_waveform_loss', c_waveform_loss.item(), 'c_stft_loss', c_stft_loss.item(), 'l1_loss', l1_loss.item())
306
  latent_diff=latent_diff.detach().item()
307
 
308
- audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten()
309
  #Andy removed: wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
310
-
311
-
312
 
313
- #Andy removed: torchaudio.save(os.path.join(audio_folder, f"protected_{audio_name}_{audio_len}_mel_{latent_diff}_diff_weight_{waveform_diff}"), torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2])), sample_rate)
314
- return torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2]))
 
 
 
315
 
316
  # encoders = [ArchiSound.from_pretrained('autoencoder1d-AT-v1'),
317
  # ArchiSound.from_pretrained("dmae1d-ATC64-v2"),
 
305
  print('step', step, 'loss', loss.item(), 'latent loss', latent_diff.item(), 'audio loss', waveform_diff.item(), 'c_waveform_loss', c_waveform_loss.item(), 'c_stft_loss', c_stft_loss.item(), 'l1_loss', l1_loss.item())
306
  latent_diff=latent_diff.detach().item()
307
 
308
+ #Andy removed: audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten()
309
  #Andy removed: wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
 
 
310
 
311
+ #Andy moved from inside the loop:
312
+ music_gen_eval_dict, unprotected_gen, protected_gen=music_gen_eval.eval(waveform, noise+waveform)
313
+
314
+ #Andy edited: torchaudio.save(os.path.join(audio_folder, f"protected_{audio_name}_{audio_len}_mel_{latent_diff}_diff_weight_{waveform_diff}"), torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2])), sample_rate)
315
+ return (torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2]))), music_gen_eval_dict, unprotected_gen, protected_gen
316
 
317
  # encoders = [ArchiSound.from_pretrained('autoencoder1d-AT-v1'),
318
  # ArchiSound.from_pretrained("dmae1d-ATC64-v2"),