import torch
import torch.nn as nn
from torchattacks.attack import Attack


class PGA(Attack):
    r"""
    Projected Gradient Ascent: an iterative gradient attack in the style of
    PGD [https://arxiv.org/abs/1706.06083] that performs gradient ascent on a
    configurable objective, with an optional L_p projection of the
    perturbation around the original images.
    """

    def __init__(
        self,
        model,
        alpha=20.0,
        steps=10,
        eps=100,
        relative_alpha=False,
        self_explain=False,
        use_cross_entropy_loss=False,
        pnorm=2,
        clip_min=-1.0,
        clip_max=1.0,
        clip_margin=0.0,
        eps_for_division=1e-20,
    ):
        super().__init__("PGA", model)
        self.eps = eps  # L_p radius of the projection ball; None disables projection
        self.alpha = alpha  # step size; None disables gradient normalization/scaling
        self.steps = steps  # number of ascent iterations
        self.clip_margin = clip_margin  # slack around [clip_min, clip_max]; None disables clipping
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.eps_for_division = eps_for_division  # numerical guard against division by zero
        self.supported_mode = ["default", "targeted"]
        self.use_cross_entropy_loss = use_cross_entropy_loss
        self.ce_loss = nn.CrossEntropyLoss()
        self.pnorm = pnorm  # norm order used for both step normalization and projection
        self.relative_alpha = relative_alpha  # scale the step by the current image norm
        self.self_explain = self_explain  # ascend on the squared output norm instead of a class score

    def compute_loss(self, outputs, target):
        # Objective to maximize. Three variants:
        #   self_explain:           squared norm of the outputs (target is ignored),
        #   use_cross_entropy_loss: negative cross-entropy w.r.t. target,
        #   default:                sum of the target-class logits.
        if self.self_explain:
            return 0.5 * (outputs**2).sum()
        if self.use_cross_entropy_loss:
            return -self.ce_loss(outputs, target)
        return outputs.flatten(1)[torch.arange(len(target)), target].sum()

    def clip_images_(self, images):
        # Clamp in place to [clip_min - margin, clip_max + margin];
        # clip_margin=None disables clipping. The trailing underscore marks
        # the in-place mutation, so callers may ignore the return value.
        if self.clip_margin is not None:
            torch.clamp_(
                images,
                min=self.clip_min - self.clip_margin,
                max=self.clip_max + self.clip_margin,
            )
        return images

    def forward(self, images, labels):
        r"""
        Overridden. Returns the adversarial images together with the last
        (scaled) gradient step.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self.targeted:
            target_labels = self.get_target_label(images, labels)

        adv_images = images.clone().detach()
        grad = torch.zeros_like(images)  # keeps the return well-defined when steps == 0

        for _ in range(self.steps):
            adv_images.requires_grad = True
            outputs = self.get_logits(adv_images)

            # Calculate loss: ascend towards the target class when targeted,
            # away from the true class otherwise
            if self.targeted:
                cost = self.compute_loss(outputs, target_labels)
            else:
                cost = -self.compute_loss(outputs, labels)

            # Update adversarial images
            grad = torch.autograd.grad(
                cost, adv_images, retain_graph=False, create_graph=False
            )[0]
            adv_images = adv_images.detach()

            # Per-sample L_p norms, clamped away from zero for stability
            adv_images_norms = (
                torch.norm(adv_images.flatten(1), p=self.pnorm, dim=1)
                .clamp_min(self.eps_for_division)
                .view(-1, 1, 1, 1)
            )
            grad_norms = (
                torch.norm(grad.flatten(1), p=self.pnorm, dim=1)
                .clamp_min(self.eps_for_division)
                .view(-1, 1, 1, 1)
            )

            if self.alpha is not None:
                # Take a normalized step of size alpha, optionally scaled
                # relative to the current image norm
                grad = grad / grad_norms
                if self.relative_alpha:
                    grad = grad * adv_images_norms
                grad = grad * self.alpha
            adv_images = adv_images + grad

            if self.eps is not None:
                # Project the perturbation back onto the L_p ball of radius eps
                delta = adv_images - images
                delta_norms = torch.norm(
                    delta.flatten(1), p=self.pnorm, dim=1
                ).clamp_min(self.eps_for_division)
                factor = (self.eps / delta_norms).clamp_max(1.0)
                delta = delta * factor.view(-1, 1, 1, 1)
                adv_images = images + delta

            self.clip_images_(adv_images)

        return adv_images, grad
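

# ---------------------------------------------------------------------------
# Usage sketch, not part of the attack itself. The victim model (torchvision's
# ResNet-18) and the random batch below are placeholder assumptions; swap in
# your own model and data. forward() is called directly because it returns an
# (adv_images, grad) tuple, which the base Attack.__call__ wrapper may not
# expect.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchvision.models import resnet18

    model = resnet18(weights=None).eval()

    # Fake batch: 4 RGB images in [-1, 1], matching the default clip range.
    images = torch.empty(4, 3, 224, 224).uniform_(-1.0, 1.0)
    labels = torch.randint(0, 1000, (4,))

    attack = PGA(model, alpha=20.0, steps=10, eps=100.0, pnorm=2)
    adv_images, last_grad = attack.forward(images, labels)

    # The projection step guarantees ||delta||_p <= eps for every sample.
    print("perturbation L2 norms:",
          (adv_images - images).flatten(1).norm(p=2, dim=1))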