Spaces:
Runtime error
Runtime error
Update audio_diffusion_attacks_forhf/src/test_encoder_attack.py
Browse files
audio_diffusion_attacks_forhf/src/test_encoder_attack.py
CHANGED
@@ -119,8 +119,9 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
|
|
119 |
waveform=inputs['input_values'][0]
|
120 |
#Andy removed: wandb.log({f"unperturbed {audio_name}": wandb.Audio(waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
|
121 |
waveform=torch.reshape(waveform, (1, waveform.shape[0], waveform.shape[1]))
|
122 |
-
waveform=waveform.to(device='cuda')
|
123 |
-
inputs["padding_mask"]=inputs["padding_mask"].to(device='cuda')
|
|
|
124 |
|
125 |
if method=="encoder":
|
126 |
unperturbed_waveform=waveform.clone().detach()
|
@@ -135,8 +136,10 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
|
|
135 |
style_waveform=style_inputs['input_values'][0]
|
136 |
#Andy removed: wandb.log({f"transfer style": wandb.Audio(style_waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
|
137 |
style_waveform=torch.reshape(style_waveform, (1, style_waveform.shape[0], style_waveform.shape[1]))
|
138 |
-
style_waveform=style_waveform.to(device='cuda')
|
139 |
-
|
|
|
|
|
140 |
# unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach()
|
141 |
unperturbed_waveform=style_waveform.clone().detach()
|
142 |
unperturbed_latents=[]
|
@@ -148,7 +151,7 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
|
|
148 |
|
149 |
|
150 |
noise=torch.normal(torch.zeros(waveform.shape), 0.0)
|
151 |
-
noise=noise.to(device='cuda')
|
152 |
noise.requires_grad=True
|
153 |
# waveform=torch.nn.parameter.Parameter(waveform)
|
154 |
|
@@ -169,10 +172,10 @@ def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1],
|
|
169 |
# )
|
170 |
|
171 |
downsample = torchaudio.transforms.Resample(sample_rate, 22050)
|
172 |
-
downsample=downsample.to(device='cuda')
|
173 |
cos = torch.nn.CosineSimilarity()
|
174 |
mrstft = auraloss.perceptual.FIRFilter()#auraloss.time.SISDRLoss()#torch.nn.functional.l1_loss
|
175 |
-
mrstft.to(device='cuda')
|
176 |
|
177 |
waveform_loss = losses.L1Loss()
|
178 |
stft_loss = losses.MultiScaleSTFTLoss()
|
|
|
119 |
waveform=inputs['input_values'][0]
|
120 |
#Andy removed: wandb.log({f"unperturbed {audio_name}": wandb.Audio(waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
|
121 |
waveform=torch.reshape(waveform, (1, waveform.shape[0], waveform.shape[1]))
|
122 |
+
#Andy removed: waveform=waveform.to(device='cuda')
|
123 |
+
#Andy edited: inputs["padding_mask"]=inputs["padding_mask"].to(device='cuda')
|
124 |
+
inputs["padding_mask"]=inputs["padding_mask"]
|
125 |
|
126 |
if method=="encoder":
|
127 |
unperturbed_waveform=waveform.clone().detach()
|
|
|
136 |
style_waveform=style_inputs['input_values'][0]
|
137 |
#Andy removed: wandb.log({f"transfer style": wandb.Audio(style_waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
|
138 |
style_waveform=torch.reshape(style_waveform, (1, style_waveform.shape[0], style_waveform.shape[1]))
|
139 |
+
#Andy edited: style_waveform=style_waveform.to(device='cuda')
|
140 |
+
style_waveform=style_waveform
|
141 |
+
#Andy edited: style_inputs["padding_mask"]=style_inputs["padding_mask"].to(device='cuda')
|
142 |
+
style_inputs["padding_mask"]=style_inputs["padding_mask"]
|
143 |
# unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach()
|
144 |
unperturbed_waveform=style_waveform.clone().detach()
|
145 |
unperturbed_latents=[]
|
|
|
151 |
|
152 |
|
153 |
noise=torch.normal(torch.zeros(waveform.shape), 0.0)
|
154 |
+
#Andy removed: noise=noise.to(device='cuda')
|
155 |
noise.requires_grad=True
|
156 |
# waveform=torch.nn.parameter.Parameter(waveform)
|
157 |
|
|
|
172 |
# )
|
173 |
|
174 |
downsample = torchaudio.transforms.Resample(sample_rate, 22050)
|
175 |
+
#Andy removed: downsample=downsample.to(device='cuda')
|
176 |
cos = torch.nn.CosineSimilarity()
|
177 |
mrstft = auraloss.perceptual.FIRFilter()#auraloss.time.SISDRLoss()#torch.nn.functional.l1_loss
|
178 |
+
#Andy removed: mrstft.to(device='cuda')
|
179 |
|
180 |
waveform_loss = losses.L1Loss()
|
181 |
stft_loss = losses.MultiScaleSTFTLoss()
|