dzy7e commited on
Commit
49d1787
·
1 Parent(s): b8c9f75
Files changed (7) hide show
  1. app.py +49 -0
  2. attack.py +113 -0
  3. attacker/FGSM.py +48 -0
  4. attacker/PGD.py +84 -0
  5. attacker/__init__.py +3 -0
  6. attacker/base.py +33 -0
  7. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from torchvision.utils import save_image
3
+ from attack import Attacker
4
+ import argparse
5
+
6
+
7
+ def do_attack(img, eps, step_size, steps, progress=gr.Progress()):
8
+ args=argparse.Namespace()
9
+ args.out_dir='./'
10
+ args.target='auto'
11
+ args.eps=eps
12
+ args.step_size=step_size
13
+ args.steps=steps
14
+ args.test_atk=False
15
+
16
+ step = progress.tqdm(range(steps))
17
+
18
+ def pdg_prog(ori_images, images, labels):
19
+ step.update(1)
20
+
21
+ attacker = Attacker(args, pgd_callback=pdg_prog)
22
+ atk_img, noise = attacker.attack_(img)
23
+ attacker.save_image(atk_img, noise, 'out.png')
24
+ return 'out.png'
25
+
26
+ with gr.Blocks(title="Anime AI Detect Fucker Demo", theme="dark") as demo:
27
+ gr.HTML('<a href="https://github.com/7eu7d7/anime-ai-detect-fucker">github repo</a>')
28
+
29
+ with gr.Row():
30
+ eps = gr.Slider(label="eps (Noise intensity)", minimum=1, maximum=16, step=1, value=1)
31
+ step_size = gr.Slider(label="Noise step size", minimum=0.001, maximum=16, step=0.001, value=0.136)
32
+ with gr.Row():
33
+ steps = gr.Slider(label="step count", minimum=1, maximum=100, step=1, value=20)
34
+ model_name = gr.Dropdown(label="attack target",
35
+ choices=["auto", "human", "ai"],
36
+ value="auto", show_label=True)
37
+
38
+ input_image = gr.Image(label="Clean Image", type="pil")
39
+
40
+ atk_btn = gr.Button("Attack")
41
+
42
+ with gr.Column():
43
+ output_image = gr.Image(label="Attacked Image")
44
+
45
+ atk_btn.click(fn=do_attack,
46
+ inputs=[input_image, eps, step_size, steps],
47
+ outputs=output_image)
48
+
49
+ demo.launch()
attack.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from transformers import BeitFeatureExtractor, BeitForImageClassification
4
+ from PIL import Image
5
+
6
+ from torchvision.utils import save_image
7
+ import torch.nn.functional as F
8
+ from torchvision import transforms
9
+
10
+ from attacker import *
11
+ from torch.nn import CrossEntropyLoss
12
+
13
+ import argparse
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ def make_args():
18
+ parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
19
+
20
+ parser.add_argument('inputs', type=str)
21
+ parser.add_argument('--out_dir', type=str, default='./output')
22
+ parser.add_argument('--target', type=str, default='auto', help='[auto, ai, human]')
23
+ parser.add_argument('--eps', type=float, default=8/8, help='Noise intensity ')
24
+ parser.add_argument('--step_size', type=float, default=1.087313/8, help='Attack step size')
25
+ parser.add_argument('--steps', type=int, default=20, help='Attack step count')
26
+
27
+ parser.add_argument('--test_atk', action='store_true')
28
+
29
+ return parser.parse_args()
30
+
31
+ class Attacker:
32
+ def __init__(self, args, pgd_callback):
33
+ self.args=args
34
+ os.makedirs(args.out_dir, exist_ok=True)
35
+
36
+ print('正在加载模型...')
37
+ self.feature_extractor = BeitFeatureExtractor.from_pretrained('saltacc/anime-ai-detect')
38
+ self.model = BeitForImageClassification.from_pretrained('saltacc/anime-ai-detect').cuda()
39
+ print('加载完毕')
40
+
41
+ if args.target=='ai': #攻击成被识别为AI
42
+ self.target = torch.tensor([1]).to(device)
43
+ elif args.target=='human':
44
+ self.target = torch.tensor([0]).to(device)
45
+
46
+ dataset_mean_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).cuda()
47
+ dataset_std_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).cuda()
48
+ self.pgd = PGD(self.model, img_transform=(lambda x: (x - dataset_mean_t) / dataset_std_t, lambda x: x * dataset_std_t + dataset_mean_t))
49
+ self.pgd.set_para(eps=(args.eps * 2) / 255, alpha=lambda: (args.step_size * 2) / 255, iters=args.steps)
50
+ self.pgd.set_loss(CrossEntropyLoss())
51
+ self.pgd.set_call_back(pgd_callback)
52
+
53
+ def save_image(self, image, noise, img_name):
54
+ # 缩放图片只缩放噪声
55
+ W, H = image.size
56
+ noise = F.interpolate(noise, size=(H, W), mode='bicubic')
57
+ img_save = transforms.ToTensor()(image) + noise
58
+ save_image(img_save, os.path.join(self.args.out_dir, f'{img_name[:img_name.rfind(".")]}_atk.png'))
59
+
60
+ def attack_(self, image):
61
+ inputs = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].cuda()
62
+
63
+ if self.args.target == 'auto':
64
+ with torch.no_grad():
65
+ outputs = self.model(inputs)
66
+ logits = outputs.logits
67
+ cls = logits.argmax(-1).item()
68
+ target = torch.tensor([cls]).to(device)
69
+ else:
70
+ target = self.target
71
+
72
+ if self.args.test_atk:
73
+ self.test_image(inputs, 'before attack')
74
+
75
+ atk_img = self.pgd.attack(inputs, target)
76
+
77
+ noise = self.pgd.img_transform[1](atk_img).detach().cpu() - self.pgd.img_transform[1](inputs).detach().cpu()
78
+
79
+ if self.args.test_atk:
80
+ self.test_image(atk_img, 'after attack')
81
+
82
+ return atk_img, noise
83
+
84
+ def attack_one(self, path):
85
+ image = Image.open(path).convert('RGB')
86
+ atk_img, noise = self.attack_(image)
87
+ self.save_image(image, noise, os.path.basename(path))
88
+
89
+ def attack(self, path):
90
+ count=0
91
+ if os.path.isdir(path):
92
+ img_list=[os.path.join(path, x) for x in os.listdir(path)]
93
+ for img in img_list:
94
+ if (img.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff'))):
95
+ self.attack_one(img)
96
+ count+=1
97
+ else:
98
+ if (path.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff'))):
99
+ self.attack_one(path)
100
+ count += 1
101
+ print(f'总共攻击{count}张图像')
102
+
103
+ @torch.no_grad()
104
+ def test_image(self, img, pre_fix):
105
+ outputs = self.model(img)
106
+ logits = outputs.logits
107
+ predicted_class_idx = logits.argmax(-1).item()
108
+ print(pre_fix, "class:", self.model.config.id2label[predicted_class_idx], 'logits:', logits)
109
+
110
+ if __name__ == '__main__':
111
+ args=make_args()
112
+ attacker = Attacker(args)
113
+ attacker.attack(args.inputs)
attacker/FGSM.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from copy import deepcopy
4
+ from .base import Attacker
5
+ from torch.cuda import amp
6
+
7
+ class FGSM(Attacker):
8
+ def __init__(self, model, img_transform=(lambda x:x, lambda x:x), use_amp=False):
9
+ super().__init__(model, img_transform)
10
+ self.use_amp=use_amp
11
+
12
+ if use_amp:
13
+ self.scaler = amp.GradScaler()
14
+
15
+ def set_para(self, eps=8, alpha=lambda:8, **kwargs):
16
+ super().set_para(eps=eps, alpha=alpha, **kwargs)
17
+
18
+ def step(self, images, labels, loss):
19
+ with amp.autocast(enabled=self.use_amp):
20
+ images.requires_grad = True
21
+ outputs = self.model(images).logits
22
+
23
+ self.model.zero_grad()
24
+ cost = loss(outputs, labels)
25
+
26
+ if self.use_amp:
27
+ self.scaler.scale(cost).backward()
28
+ else:
29
+ cost.backward()
30
+
31
+ adv_images = (images + self.alpha() * images.grad.sign()).detach_()
32
+ eta = torch.clamp(adv_images - self.ori_images, min=-self.eps, max=self.eps)
33
+ images = self.img_transform[0](torch.clamp(self.img_transform[1](self.ori_images + eta), min=0, max=255).detach_())
34
+
35
+ return images
36
+
37
+ def attack(self, images, labels):
38
+ #images = deepcopy(images)
39
+ #self.ori_images = deepcopy(images)
40
+
41
+ self.model.eval()
42
+
43
+ images = self.forward(self, images, labels)
44
+
45
+ self.model.zero_grad()
46
+ self.model.train()
47
+
48
+ return images
attacker/PGD.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from copy import deepcopy
4
+ from .base import Attacker, Empty
5
+ from torch.cuda import amp
6
+ from tqdm import tqdm
7
+
8
+ class PGD(Attacker):
9
+ def __init__(self, model, img_transform=(lambda x:x, lambda x:x), use_amp=False):
10
+ super().__init__(model, img_transform)
11
+ self.use_amp=use_amp
12
+ self.call_back=None
13
+ self.img_loader=None
14
+ self.img_hook=None
15
+
16
+ self.scaler = amp.GradScaler(enabled=use_amp)
17
+
18
+ def set_para(self, eps=8, alpha=lambda:8, iters=20, **kwargs):
19
+ super().set_para(eps=eps, alpha=alpha, iters=iters, **kwargs)
20
+
21
+ def set_call_back(self, call_back):
22
+ self.call_back=call_back
23
+
24
+ def set_img_loader(self, img_loader):
25
+ self.img_loader=img_loader
26
+
27
+ def step(self, images, labels, loss):
28
+ with amp.autocast(enabled=self.use_amp):
29
+ images.requires_grad = True
30
+ outputs = self.model(images).logits
31
+
32
+ self.model.zero_grad()
33
+ cost = loss(outputs, labels)#+outputs[2].view(-1)[0]*0+outputs[1].view(-1)[0]*0+outputs[0].view(-1)[0]*0 #support DDP
34
+
35
+ self.scaler.scale(cost).backward()
36
+
37
+ adv_images = (images + self.alpha() * images.grad.sign()).detach_()
38
+ eta = torch.clamp(adv_images - self.ori_images, min=-self.eps, max=self.eps)
39
+ images = self.img_transform[0](torch.clamp(self.img_transform[1](self.ori_images + eta), min=0, max=1).detach_())
40
+
41
+ return images
42
+
43
+ def set_data(self, images, labels):
44
+ self.ori_images = deepcopy(images)
45
+ self.images = images
46
+ self.labels = labels
47
+
48
+ def __iter__(self):
49
+ self.atk_step=0
50
+ return self
51
+
52
+ def __next__(self):
53
+ self.atk_step += 1
54
+ if self.atk_step>self.iters:
55
+ raise StopIteration
56
+
57
+ with self.model.no_sync() if isinstance(self.model, nn.parallel.DistributedDataParallel) else Empty():
58
+ self.model.eval()
59
+
60
+ self.images = self.forward(self, self.images, self.labels)
61
+
62
+ self.model.zero_grad()
63
+ self.model.train()
64
+
65
+ return self.ori_images, self.images.detach(), self.labels
66
+
67
+ def attack(self, images, labels):
68
+ #images = deepcopy(images)
69
+ self.ori_images = deepcopy(images)
70
+
71
+ for i in tqdm(range(self.iters)):
72
+ self.model.eval()
73
+
74
+ images = self.forward(self, images, labels)
75
+
76
+ self.model.zero_grad()
77
+ self.model.train()
78
+ if self.call_back:
79
+ self.call_back(self.ori_images, images.detach(), labels)
80
+
81
+ if self.img_hook is not None:
82
+ images=self.img_hook(self.ori_images, images.detach())
83
+
84
+ return images
attacker/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .base import *
2
+ from .PGD import *
3
+ from .FGSM import *
attacker/base.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ class Attacker:
3
+ def __init__(self, model, img_transform=(lambda x:x, lambda x:x)):
4
+ self.model = model # 必须是pytorch的model
5
+ '''self.model.eval()
6
+ for k, v in self.model.named_parameters():
7
+ v.requires_grad = False'''
8
+ self.img_transform=img_transform
9
+ self.forward = lambda attacker, images, labels: attacker.step(images, labels, attacker.loss)
10
+
11
+ def set_para(self, **kwargs):
12
+ for k,v in kwargs.items():
13
+ setattr(self, k,v)
14
+
15
+ def set_forward(self, forward):
16
+ self.forward=forward
17
+
18
+ def step(self, images, labels, loss):
19
+ pass
20
+
21
+ def set_loss(self, loss):
22
+ self.loss=loss
23
+
24
+ def attack(self, images, labels):
25
+ pass
26
+
27
+
28
+ class Empty:
29
+ def __enter__(self):
30
+ pass
31
+
32
+ def __exit__(self, exc_type, exc_val, exc_tb):
33
+ pass
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.12.1
2
+ torchvision==0.13.1
3
+ timm==0.6.12
4
+ Pillow
5
+ blobfile
6
+ mypy
7
+ numpy
8
+ pytest
9
+ requests
10
+ einops
11
+ deepspeed==0.4.0
12
+ scipy