# ProArd/attacks/apgd_ce.py
import numpy as np
import torch.nn as nn
from typing import Dict

from autoattack import AutoAttack
from torchpack import distributed as dist
from tqdm import tqdm

from utils import accuracy
from utils.distributed import DistributedMetric
from .base import Attack, LabelMixin
from .utils import ctx_noparamgrad_and_eval

class Autoattack_apgd_ce(Attack, LabelMixin):
    """Wrapper that runs the APGD-CE component of AutoAttack through the Attack interface."""

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True,
            clip_min=0., clip_max=1., ord=np.inf, targeted=False, rand_init_type='uniform'):
        super(Autoattack_apgd_ce, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.nb_iter = nb_iter
        self.eps_iter = eps_iter
        self.rand_init = rand_init
        self.rand_init_type = rand_init_type
        self.ord = ord
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
        # AutoAttack wraps the model; perturb() restricts it to the L-inf 'apgd-ce' attack.
        self.adversary = AutoAttack(predict, norm='Linf', eps=self.eps, version='standard')

    def perturb(self, x, y=None):
        # Run only the APGD-CE attack from the AutoAttack suite.
        self.adversary.attacks_to_run = ['apgd-ce']
        adversarial_examples = self.adversary.run_standard_evaluation(x, y, bs=100)
        # Returned twice so callers can unpack an (x_adv, _) pair.
        return adversarial_examples, adversarial_examples
    def eval_AutoAttack_apgd_ce(self, data_loader_dict: Dict) -> Dict:
        test_criterion = nn.CrossEntropyLoss().cuda()

        # Distributed running averages for clean and adversarial metrics.
        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()

        self.predict.eval()
        with tqdm(
            total=len(data_loader_dict["val"]),
            desc="Eval",
            disable=not dist.is_master(),
        ) as t:
            for images, labels in data_loader_dict["val"]:
                images, labels = images.cuda(), labels.cuda()

                # Clean loss and accuracy.
                output = self.predict(images)
                loss = test_criterion(output, labels)
                val_loss.update(loss, images.shape[0])
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                val_top5.update(acc5[0], images.shape[0])
                val_top1.update(acc1[0], images.shape[0])

                # Adversarial loss and accuracy; parameter gradients are disabled while crafting the attack.
                with ctx_noparamgrad_and_eval(self.predict):
                    images_adv, _ = self.perturb(images, labels)
                output_adv = self.predict(images_adv)
                loss_adv = test_criterion(output_adv, labels)
                val_advloss.update(loss_adv, images.shape[0])
                acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                val_advtop1.update(acc1_adv[0], images.shape[0])
                val_advtop5.update(acc5_adv[0], images.shape[0])

                t.set_postfix(
                    {
                        "loss": val_loss.avg.item(),
                        "top1": val_top1.avg.item(),
                        "top5": val_top5.avg.item(),
                        "adv_loss": val_advloss.avg.item(),
                        "adv_top1": val_advtop1.avg.item(),
                        "adv_top5": val_advtop5.avg.item(),
                        "#samples": val_top1.count.item(),
                        "batch_size": images.shape[0],
                        "img_size": images.shape[2],
                    }
                )
                t.update()

        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results
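

# Minimal usage sketch (an assumption, not part of the original module): `model` is a GPU
# classifier and `data_loader_dict` is the ProArd-style dict with a "val" loader; eps=8/255 is
# an example budget for CIFAR-like inputs. Kept as a comment because the evaluation loop
# assumes an initialized torchpack distributed context.
#
#     attack = Autoattack_apgd_ce(model, eps=8 / 255)
#     x_adv, _ = attack.perturb(images, labels)                # adversarial batch only
#     results = attack.eval_AutoAttack_apgd_ce(data_loader_dict)
#     print(results["val_advtop1"])                            # robust top-1 accuracy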