Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
from copy import deepcopy | |
from .base import Attacker | |
from torch.cuda import amp | |
class FGSM(Attacker): | |
def __init__(self, model, img_transform=(lambda x:x, lambda x:x), use_amp=False): | |
super().__init__(model, img_transform) | |
self.use_amp=use_amp | |
if use_amp: | |
self.scaler = amp.GradScaler() | |
def set_para(self, eps=8, alpha=lambda:8, **kwargs): | |
super().set_para(eps=eps, alpha=alpha, **kwargs) | |
def step(self, images, labels, loss): | |
with amp.autocast(enabled=self.use_amp): | |
images.requires_grad = True | |
outputs = self.model(images).logits | |
self.model.zero_grad() | |
cost = loss(outputs, labels) | |
if self.use_amp: | |
self.scaler.scale(cost).backward() | |
else: | |
cost.backward() | |
adv_images = (images + self.alpha() * images.grad.sign()).detach_() | |
eta = torch.clamp(adv_images - self.ori_images, min=-self.eps, max=self.eps) | |
images = self.img_transform[0](torch.clamp(self.img_transform[1](self.ori_images + eta), min=0, max=255).detach_()) | |
return images | |
def attack(self, images, labels): | |
#images = deepcopy(images) | |
#self.ori_images = deepcopy(images) | |
self.model.eval() | |
images = self.forward(self, images, labels) | |
self.model.zero_grad() | |
self.model.train() | |
return images |