Spaces:
Runtime error
Runtime error
Update audio_diffusion_attacks_forhf/src/test_encoder_attack.py
Browse files
audio_diffusion_attacks_forhf/src/test_encoder_attack.py
CHANGED
@@ -305,13 +305,14 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
|
|
305 |
print('step', step, 'loss', loss.item(), 'latent loss', latent_diff.item(), 'audio loss', waveform_diff.item(), 'c_waveform_loss', c_waveform_loss.item(), 'c_stft_loss', c_stft_loss.item(), 'l1_loss', l1_loss.item())
|
306 |
latent_diff=latent_diff.detach().item()
|
307 |
|
308 |
-
audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten()
|
309 |
#Andy removed: wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
|
310 |
-
|
311 |
-
|
312 |
|
313 |
-
#Andy
|
314 |
-
|
|
|
|
|
|
|
315 |
|
316 |
# encoders = [ArchiSound.from_pretrained('autoencoder1d-AT-v1'),
|
317 |
# ArchiSound.from_pretrained("dmae1d-ATC64-v2"),
|
|
|
305 |
print('step', step, 'loss', loss.item(), 'latent loss', latent_diff.item(), 'audio loss', waveform_diff.item(), 'c_waveform_loss', c_waveform_loss.item(), 'c_stft_loss', c_stft_loss.item(), 'l1_loss', l1_loss.item())
|
306 |
latent_diff=latent_diff.detach().item()
|
307 |
|
308 |
+
#Andy removed: audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten()
|
309 |
#Andy removed: wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
|
|
|
|
|
310 |
|
311 |
+
#Andy moved from inside the loop:
|
312 |
+
music_gen_eval_dict, unprotected_gen, protected_gen=music_gen_eval.eval(waveform, noise+waveform)
|
313 |
+
|
314 |
+
#Andy edited: torchaudio.save(os.path.join(audio_folder, f"protected_{audio_name}_{audio_len}_mel_{latent_diff}_diff_weight_{waveform_diff}"), torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2])), sample_rate)
|
315 |
+
return (torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2]))), music_gen_eval_dict, unprotected_gen, protected_gen
|
316 |
|
317 |
# encoders = [ArchiSound.from_pretrained('autoencoder1d-AT-v1'),
|
318 |
# ArchiSound.from_pretrained("dmae1d-ATC64-v2"),
|