ALeLacheur commited on
Commit
e5d0122
·
verified ·
1 Parent(s): 8fae7b2

Update audio_diffusion_attacks_forhf/src/test_encoder_attack.py

Browse files
audio_diffusion_attacks_forhf/src/test_encoder_attack.py CHANGED
@@ -119,8 +119,9 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
119
  waveform=inputs['input_values'][0]
120
  #Andy removed: wandb.log({f"unperturbed {audio_name}": wandb.Audio(waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
121
  waveform=torch.reshape(waveform, (1, waveform.shape[0], waveform.shape[1]))
122
- waveform=waveform.to(device='cuda')
123
- inputs["padding_mask"]=inputs["padding_mask"].to(device='cuda')
 
124
 
125
  if method=="encoder":
126
  unperturbed_waveform=waveform.clone().detach()
@@ -135,8 +136,10 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
135
  style_waveform=style_inputs['input_values'][0]
136
  #Andy removed: wandb.log({f"transfer style": wandb.Audio(style_waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
137
  style_waveform=torch.reshape(style_waveform, (1, style_waveform.shape[0], style_waveform.shape[1]))
138
- style_waveform=style_waveform.to(device='cuda')
139
- style_inputs["padding_mask"]=style_inputs["padding_mask"].to(device='cuda')
 
 
140
  # unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach()
141
  unperturbed_waveform=style_waveform.clone().detach()
142
  unperturbed_latents=[]
@@ -148,7 +151,7 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
148
 
149
 
150
  noise=torch.normal(torch.zeros(waveform.shape), 0.0)
151
- noise=noise.to(device='cuda')
152
  noise.requires_grad=True
153
  # waveform=torch.nn.parameter.Parameter(waveform)
154
 
@@ -169,10 +172,10 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
169
  # )
170
 
171
  downsample = torchaudio.transforms.Resample(sample_rate, 22050)
172
- downsample=downsample.to(device='cuda')
173
  cos = torch.nn.CosineSimilarity()
174
  mrstft = auraloss.perceptual.FIRFilter()#auraloss.time.SISDRLoss()#torch.nn.functional.l1_loss
175
- mrstft.to(device='cuda')
176
 
177
  waveform_loss = losses.L1Loss()
178
  stft_loss = losses.MultiScaleSTFTLoss()
 
119
  waveform=inputs['input_values'][0]
120
  #Andy removed: wandb.log({f"unperturbed {audio_name}": wandb.Audio(waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
121
  waveform=torch.reshape(waveform, (1, waveform.shape[0], waveform.shape[1]))
122
+ #Andy removed: waveform=waveform.to(device='cuda')
123
+ #Andy edited: inputs["padding_mask"]=inputs["padding_mask"].to(device='cuda')
124
+ inputs["padding_mask"]=inputs["padding_mask"]
125
 
126
  if method=="encoder":
127
  unperturbed_waveform=waveform.clone().detach()
 
136
  style_waveform=style_inputs['input_values'][0]
137
  #Andy removed: wandb.log({f"transfer style": wandb.Audio(style_waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
138
  style_waveform=torch.reshape(style_waveform, (1, style_waveform.shape[0], style_waveform.shape[1]))
139
+ #Andy edited: style_waveform=style_waveform.to(device='cuda')
140
+ style_waveform=style_waveform
141
+ #Andy edited: style_inputs["padding_mask"]=style_inputs["padding_mask"].to(device='cuda')
142
+ style_inputs["padding_mask"]=style_inputs["padding_mask"]
143
  # unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach()
144
  unperturbed_waveform=style_waveform.clone().detach()
145
  unperturbed_latents=[]
 
151
 
152
 
153
  noise=torch.normal(torch.zeros(waveform.shape), 0.0)
154
+ #Andy removed: noise=noise.to(device='cuda')
155
  noise.requires_grad=True
156
  # waveform=torch.nn.parameter.Parameter(waveform)
157
 
 
172
  # )
173
 
174
  downsample = torchaudio.transforms.Resample(sample_rate, 22050)
175
+ #Andy removed: downsample=downsample.to(device='cuda')
176
  cos = torch.nn.CosineSimilarity()
177
  mrstft = auraloss.perceptual.FIRFilter()#auraloss.time.SISDRLoss()#torch.nn.functional.l1_loss
178
+ #Andy removed: mrstft.to(device='cuda')
179
 
180
  waveform_loss = losses.L1Loss()
181
  stft_loss = losses.MultiScaleSTFTLoss()