import torch
import torch.nn as nn

from torchattacks.attack import Attack


class PGA(Attack):
    r"""
    Projected Gradient Ascent: a gradient-ascent formulation of PGD
    [https://arxiv.org/abs/1706.06083] with optional Lp projection (``eps``)
    and value clipping (``clip_min``/``clip_max``).
    """

    def __init__(
        self,
        model,
        alpha=20.0,
        steps=10,
        eps=100,
        relative_alpha=False,
        self_explain=False,
        use_cross_entropy_loss=False,
        pnorm=2,
        clip_min=-1.0,
        clip_max=1.0,
        clip_margin=0.0,
        eps_for_division=1e-20,
    ):
        super().__init__("PGA", model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.clip_margin = clip_margin
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.eps_for_division = eps_for_division
        self.supported_mode = ["default", "targeted"]
        self.use_cross_entropy_loss = use_cross_entropy_loss
        self.ce_loss = nn.CrossEntropyLoss()
        self.pnorm = pnorm
        self.relative_alpha = relative_alpha
        self.self_explain = self_explain

    def compute_loss(self, outputs, target):
        # Label-free objective: half the squared norm of the outputs.
        if self.self_explain:
            return 0.5 * (outputs**2).sum()
        if self.use_cross_entropy_loss:
            # Negative cross-entropy: ascending it pulls outputs towards `target`.
            return -self.ce_loss(outputs, target)
        else:
            # Default: sum of the logits indexed by `target`.
            return outputs.flatten(1)[torch.arange(len(target)), target].sum()

    def clip_images_(self, images):
        # In-place clamp to [clip_min - margin, clip_max + margin];
        # a `clip_margin` of None disables clipping entirely.
        if self.clip_margin is not None:
            return images.clamp_(
                min=self.clip_min - self.clip_margin,
                max=self.clip_max + self.clip_margin,
            ).detach()

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self.targeted:
            target_labels = self.get_target_label(images, labels)

        adv_images = images.clone().detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            outputs = self.get_logits(adv_images)

            # Calculate loss
            if self.targeted:
                cost = self.compute_loss(outputs, target_labels)
            else:
                cost = -self.compute_loss(outputs, labels)

            # Update adversarial images
            grad = torch.autograd.grad(
                cost, adv_images, retain_graph=False, create_graph=False
            )[0]

            adv_images = adv_images.detach()
            adv_images_norms = (
                torch.norm(adv_images.flatten(1), p=self.pnorm, dim=1)
                .clamp_min(self.eps_for_division)
                .view(-1, 1, 1, 1)
            )
            grad_norms = (
                torch.norm(grad.flatten(1), p=self.pnorm, dim=1)
                .clamp_min(self.eps_for_division)
                .view(-1, 1, 1, 1)
            )

            if self.alpha is not None:
                # Normalize the gradient to unit Lp norm, then scale by the
                # step size (optionally relative to the current image norm).
                grad = grad / grad_norms
                if self.relative_alpha:
                    grad = grad * adv_images_norms
                grad = grad * self.alpha
            adv_images = adv_images + grad

            if self.eps is not None:
                # Project the perturbation back onto the Lp ball of radius eps
                # (a no-op for perturbations already inside the ball).
                delta = adv_images - images
                delta_norms = torch.norm(
                    delta.flatten(1), p=self.pnorm, dim=1
                ).clamp_min(self.eps_for_division)
                factor = self.eps / delta_norms
                factor = torch.min(factor, torch.ones_like(delta_norms))
                delta = delta * factor.view(-1, 1, 1, 1)
                adv_images = images + delta
            self.clip_images_(adv_images)

        # Return the final images together with the last (scaled) gradient.
        return adv_images, grad
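

# ---------------------------------------------------------------------------
# Minimal usage sketch. The model, input sizes, and hyperparameters below are
# illustrative assumptions, not part of the attack itself; the sketch also
# assumes a torchattacks version whose `__call__` passes `forward`'s return
# value through unchanged, since this `forward` returns an (images, grad)
# tuple rather than the single tensor torchattacks usually expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torchvision

    model = torchvision.models.resnet18(weights=None).eval()
    atk = PGA(model, alpha=2.0, steps=10, eps=3.0, use_cross_entropy_loss=True)

    # Fake batch scaled to the [-1, 1] range matching the default
    # clip_min/clip_max; real inputs should use the same value range.
    images = torch.rand(4, 3, 224, 224) * 2 - 1
    labels = torch.randint(0, 1000, (4,))

    adv_images, last_grad = atk(images, labels)
    print(adv_images.shape, last_grad.shape)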