ALeLacheur commited on
Commit
87dbf20
·
verified ·
1 Parent(s): 8646e59

Update audio_diffusion_attacks_forhf/src/test_encoder_attack.py

Browse files
audio_diffusion_attacks_forhf/src/test_encoder_attack.py CHANGED
@@ -86,14 +86,12 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
86
  audio_folder: string, path to folder of audio files. Protected audio files will be saved in that folder.
87
  encoders: encoders to protect against. See initialization at end of file.
88
  '''
89
- print("breakpoint 1")
90
  for encoder in encoders:
91
  #Andy removed: encoder.to(device='cuda')
92
  encoder.eval()
93
  for p in encoder.parameters():
94
  p.requires_grad = False
95
 
96
- print("breakpoint 2")
97
  audio_len=1000000
98
  #Andy removed: waveform, sample_rate = torchaudio.load(f"test_audio/Texas Sun.mp3")
99
  if modality=="music":
@@ -101,7 +99,8 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
101
  elif modality=="speech":
102
  music_gen_eval=XTTS_Eval(sample_rate)
103
  processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
104
- loss_fn = cdpam.CDPAM(dev='cuda:0')
 
105
  for p in loss_fn.model.parameters():
106
  p.requires_grad = False
107
 
 
86
  audio_folder: string, path to folder of audio files. Protected audio files will be saved in that folder.
87
  encoders: encoders to protect against. See initialization at end of file.
88
  '''
 
89
  for encoder in encoders:
90
  #Andy removed: encoder.to(device='cuda')
91
  encoder.eval()
92
  for p in encoder.parameters():
93
  p.requires_grad = False
94
 
 
95
  audio_len=1000000
96
  #Andy removed: waveform, sample_rate = torchaudio.load(f"test_audio/Texas Sun.mp3")
97
  if modality=="music":
 
99
  elif modality=="speech":
100
  music_gen_eval=XTTS_Eval(sample_rate)
101
  processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
102
+ #Andy edited: loss_fn = cdpam.CDPAM(dev='cuda:0')
103
+ loss_fn = cdpam.CDPAM()
104
  for p in loss_fn.model.parameters():
105
  p.requires_grad = False
106