onehowon commited on
Commit
4b3aa59
ยท
verified ยท
1 Parent(s): c10f63e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -77
app.py CHANGED
@@ -3,114 +3,184 @@ import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
5
  from torchvision import transforms, models
6
- from art.attacks.evasion import FastGradientMethod
 
 
 
 
7
  from art.estimators.classification import PyTorchClassifier
8
- from PIL import Image
9
  import numpy as np
10
  import os
11
- import io
12
  from blind_watermark import WaterMark
 
 
13
 
14
- # Pretrained ResNet50 ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ (ImageNet ์‚ฌ์ „ ํ›ˆ๋ จ)
15
- model = models.resnet50(pretrained=True)
16
-
17
- # CIFAR-10์— ๋งž์ถฐ ๋งˆ์ง€๋ง‰ ๋ถ„๋ฅ˜ ๋ ˆ์ด์–ด ์ˆ˜์ •
18
- num_ftrs = model.fc.in_features
19
- model.fc = nn.Linear(num_ftrs, 10)
20
 
21
- # ๋ชจ๋ธ์„ GPU๋กœ ์ด๋™
22
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
- model = model.to(device)
 
24
 
25
- # ์†์‹ค ํ•จ์ˆ˜์™€ ์˜ตํ‹ฐ๋งˆ์ด์ € ์„ค์ •
26
  criterion = nn.CrossEntropyLoss()
27
- optimizer = optim.Adam(model.parameters(), lr=0.001)
 
 
 
 
 
 
 
 
 
28
 
29
- # PyTorchClassifier ์ƒ์„ฑ
30
- classifier = PyTorchClassifier(
31
- model=model,
32
  loss=criterion,
33
- optimizer=optimizer,
34
- input_shape=(3, 64, 64),
35
  nb_classes=10,
36
  )
37
 
38
- # ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜
 
 
 
 
39
  def preprocess_image(image):
40
  transform = transforms.Compose([
 
41
  transforms.ToTensor(),
42
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
43
- std=[0.229, 0.224, 0.225])
44
  ])
45
- return transform(image).unsqueeze(0).to(device)
46
-
47
- # FGSM ๊ณต๊ฒฉ ์ ์šฉ ๋ฐ ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜
48
- def generate_adversarial_image(image, eps_value):
49
- img_tensor = preprocess_image(image)
50
-
51
- # FGSM ๊ณต๊ฒฉ ์„ค์ •
52
- attack = FastGradientMethod(estimator=classifier, eps=eps_value)
53
-
54
- # ์ ๋Œ€์  ์˜ˆ์ œ ์ƒ์„ฑ
55
- adv_img_tensor = attack.generate(x=img_tensor.cpu().numpy())
56
- adv_img_tensor = torch.tensor(adv_img_tensor).to(device)
57
 
58
- # ์ ๋Œ€์  ์ด๋ฏธ์ง€ ๋ณ€ํ™˜
59
- adv_img_np = adv_img_tensor.squeeze(0).cpu().numpy()
60
  mean = np.array([0.485, 0.456, 0.406])
61
  std = np.array([0.229, 0.224, 0.225])
62
  adv_img_np = (adv_img_np * std[:, None, None]) + mean[:, None, None]
63
  adv_img_np = np.clip(adv_img_np, 0, 1)
64
  adv_img_np = adv_img_np.transpose(1, 2, 0)
65
-
66
- # PIL ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜
67
- adv_image_pil = Image.fromarray((adv_img_np * 255).astype(np.uint8))
68
-
69
  return adv_image_pil
70
 
71
- # ์›Œํ„ฐ๋งˆํฌ ์‚ฝ์ž… ํ•จ์ˆ˜
72
- def apply_watermark(image_pil, wm_text="ํ…์ŠคํŠธ ์‚ฝ์ž…", password_img=000, password_wm=000):
73
- bwm = WaterMark(password_img=password_img, password_wm=password_wm)
74
-
75
- # ์ด๋ฏธ์ง€ ๋ฐ”์ดํŠธ ๋ฐ์ดํ„ฐ๋ฅผ ์ž„์‹œ ํŒŒ์ผ๋กœ ์ €์žฅ
76
- temp_image_path = "temp_image.png"
77
- image_pil.save(temp_image_path)
78
 
79
- # temp_image_path ๊ฒฝ๋กœ๋กœ ์›Œํ„ฐ๋งˆํฌ ์‚ฝ์ž… ์ฒ˜๋ฆฌ
80
- bwm.read_img(temp_image_path)
81
- bwm.read_wm(wm_text, mode='str')
82
-
83
- # ์›Œํ„ฐ๋งˆํฌ ์‚ฝ์ž…
84
- output_path = "watermarked_image.png"
85
- bwm.embed(output_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # ์‚ฝ์ž…๋œ ์›Œํ„ฐ๋งˆํฌ ์ด๋ฏธ์ง€ ํŒŒ์ผ์„ ๋‹ค์‹œ ์ฝ์–ด์„œ PIL ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜
88
- result_image = Image.open(output_path)
 
 
 
89
 
90
- # ์ž„์‹œ ํŒŒ์ผ ์‚ญ์ œ
91
- os.remove(temp_image_path)
92
- os.remove(output_path)
 
93
 
94
- return result_image
 
 
95
 
96
- # ์ „์ฒด ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜
97
- def process_image(image, eps_value, wm_text, password_img, password_wm):
98
- # ์ ๋Œ€์  ์ด๋ฏธ์ง€ ์ƒ์„ฑ
99
- adv_image = generate_adversarial_image(image, eps_value)
100
-
101
- # ์ ๋Œ€์  ์ด๋ฏธ์ง€์— ์›Œํ„ฐ๋งˆํฌ ์‚ฝ์ž…
102
- watermarked_image = apply_watermark(adv_image, wm_text, int(password_img), int(password_wm))
103
 
104
- return watermarked_image
 
 
 
105
 
106
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ •์˜
107
- gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  fn=process_image,
109
- inputs=[gr.Image(type="pil", label="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”"), # ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ ํ•„๋“œ
110
- gr.Slider(0.1, 1.0, step=0.1, value=0.3, label="Epsilon ๊ฐ’ ์„ค์ • (๋…ธ์ด์ฆˆ ๊ฐ•๋„)"), # epsilon ๊ฐ’ ์Šฌ๋ผ์ด๋”
111
- gr.Textbox(label="์›Œํ„ฐ๋งˆํฌ ํ…์ŠคํŠธ ์ž…๋ ฅ", value="ํ…์ŠคํŠธ ์‚ฝ์ž…"), # ์›Œํ„ฐ๋งˆํฌ ํ…์ŠคํŠธ ์ž…๋ ฅ ํ•„๋“œ
112
- gr.Number(label="์ด๋ฏธ์ง€ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=000), # ์ด๋ฏธ์ง€ ๋น„๋ฐ€๋ฒˆํ˜ธ ์ž…๋ ฅ ํ•„๋“œ
113
- gr.Number(label="์›Œํ„ฐ๋งˆํฌ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=000) # ์›Œํ„ฐ๋งˆํฌ ๋น„๋ฐ€๋ฒˆํ˜ธ ์ž…๋ ฅ ํ•„๋“œ
 
 
 
114
  ],
115
- outputs=gr.Image(type="pil", label="์›Œํ„ฐ๋งˆํฌ๊ฐ€ ์‚ฝ์ž…๋œ ์ด๋ฏธ์ง€") # ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€ ์ถœ๋ ฅ ํ•„๋“œ
116
- ).launch()
 
 
 
 
 
 
 
 
 
3
  import torch.nn as nn
4
  import torch.optim as optim
5
  from torchvision import transforms, models
6
+ from art.attacks.evasion import (
7
+ FastGradientMethod, CarliniL2Method, DeepFool, AutoAttack,
8
+ ProjectedGradientDescent, BasicIterativeMethod, SpatialTransformation,
9
+ MomentumIterativeMethod, SaliencyMapMethod, NewtonFool
10
+ )
11
  from art.estimators.classification import PyTorchClassifier
12
+ from PIL import Image, ImageOps
13
  import numpy as np
14
  import os
 
15
  from blind_watermark import WaterMark
16
+ from torchvision.models import resnet50, vgg16, ResNet50_Weights, VGG16_Weights
17
+ import tempfile
18
 
19
+ resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)
20
+ num_ftrs_resnet = resnet_model.fc.in_features
21
+ resnet_model.fc = nn.Linear(num_ftrs_resnet, 10)
22
+ resnet_model = resnet_model.to("cuda" if torch.cuda.is_available() else "cpu")
 
 
23
 
24
+ vgg_model = vgg16(weights=VGG16_Weights.DEFAULT)
25
+ num_ftrs_vgg = vgg_model.classifier[6].in_features
26
+ vgg_model.classifier[6] = nn.Linear(num_ftrs_vgg, 10)
27
+ vgg_model = vgg_model.to("cuda" if torch.cuda.is_available() else "cpu")
28
 
 
29
  criterion = nn.CrossEntropyLoss()
30
+ optimizer_resnet = optim.Adam(resnet_model.parameters(), lr=0.001)
31
+ optimizer_vgg = optim.Adam(vgg_model.parameters(), lr=0.001)
32
+
33
+ resnet_classifier = PyTorchClassifier(
34
+ model=resnet_model,
35
+ loss=criterion,
36
+ optimizer=optimizer_resnet,
37
+ input_shape=(3, 224, 224),
38
+ nb_classes=10,
39
+ )
40
 
41
+ vgg_classifier = PyTorchClassifier(
42
+ model=vgg_model,
 
43
  loss=criterion,
44
+ optimizer=optimizer_vgg,
45
+ input_shape=(3, 224, 224),
46
  nb_classes=10,
47
  )
48
 
49
+ models_dict = {
50
+ "ResNet50": resnet_classifier,
51
+ "VGG16": vgg_classifier
52
+ }
53
+
54
  def preprocess_image(image):
55
  transform = transforms.Compose([
56
+ transforms.Resize((224, 224)),
57
  transforms.ToTensor(),
58
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
59
  ])
60
+ return transform(image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ def postprocess_image(tensor, original_size):
63
+ adv_img_np = tensor.squeeze(0).cpu().numpy()
64
  mean = np.array([0.485, 0.456, 0.406])
65
  std = np.array([0.229, 0.224, 0.225])
66
  adv_img_np = (adv_img_np * std[:, None, None]) + mean[:, None, None]
67
  adv_img_np = np.clip(adv_img_np, 0, 1)
68
  adv_img_np = adv_img_np.transpose(1, 2, 0)
69
+ adv_image_pil = Image.fromarray((adv_img_np * 255).astype(np.uint8)).resize(original_size)
 
 
 
70
  return adv_image_pil
71
 
72
+ def generate_adversarial_image(image, model_name, attack_types, eps_value):
73
+ original_size = image.size
74
+ img_tensor = preprocess_image(image)
 
 
 
 
75
 
76
+ classifier = models_dict[model_name]
77
+
78
+ try:
79
+ for attack_type in attack_types:
80
+ if attack_type == "FGSM":
81
+ attack = FastGradientMethod(estimator=classifier, eps=eps_value)
82
+ elif attack_type == "C&W":
83
+ attack = CarliniL2Method(classifier=classifier, confidence=0.05)
84
+ elif attack_type == "DeepFool":
85
+ attack = DeepFool(classifier=classifier, max_iter=20)
86
+ elif attack_type == "AutoAttack":
87
+ attack = AutoAttack(estimator=classifier, eps=eps_value, batch_size=1)
88
+ elif attack_type == "PGD":
89
+ attack = ProjectedGradientDescent(estimator=classifier, eps=eps_value, eps_step=eps_value / 10, max_iter=40)
90
+ elif attack_type == "BIM":
91
+ attack = BasicIterativeMethod(estimator=classifier, eps=eps_value, eps_step=eps_value / 10, max_iter=10)
92
+ elif attack_type == "STA":
93
+ attack = SpatialTransformation(estimator=classifier, max_translation=0.2)
94
+ elif attack_type == "MIM":
95
+ attack = MomentumIterativeMethod(estimator=classifier, eps=eps_value, eps_step=eps_value / 10, max_iter=10)
96
+ elif attack_type == "JSMA":
97
+ attack = SaliencyMapMethod(classifier=classifier, theta=0.1, gamma=0.1)
98
+ elif attack_type == "NewtonFool":
99
+ attack = NewtonFool(classifier=classifier, max_iter=20)
100
+
101
+ adv_img_np = attack.generate(x=img_tensor.cpu().numpy())
102
+ img_tensor = torch.tensor(adv_img_np).to("cuda" if torch.cuda.is_available() else "cpu")
103
+ except Exception as e:
104
+ print(f"Error in adversarial generation: {e}")
105
+ return image
106
+
107
+ adv_image_pil = postprocess_image(img_tensor, original_size)
108
+ return adv_image_pil
109
 
110
+ def apply_watermark(image_pil, wm_text="ํ…์ŠคํŠธ ์‚ฝ์ž…", password_img=0, password_wm=0):
111
+ try:
112
+ bwm = WaterMark(password_img=password_img, password_wm=password_wm)
113
+ temp_image_path = tempfile.mktemp(suffix=".png")
114
+ image_pil.save(temp_image_path, format="PNG")
115
 
116
+ bwm.read_img(temp_image_path)
117
+ bwm.read_wm(wm_text, mode='str')
118
+ output_path = tempfile.mktemp(suffix=".png")
119
+ bwm.embed(output_path)
120
 
121
+ result_image = Image.open(output_path).convert("RGB")
122
+ os.remove(temp_image_path)
123
+ os.remove(output_path)
124
 
125
+ return result_image
126
+ except Exception as e:
127
+ print(f"Error in apply_watermark: {str(e)}")
128
+ return image_pil
 
 
 
129
 
130
+ def extract_watermark(image_pil, password_img=0, password_wm=0):
131
+ bwm = WaterMark(password_img=password_img, password_wm=password_wm)
132
+ temp_image_path = tempfile.mktemp(suffix=".png")
133
+ image_pil.save(temp_image_path, format="PNG")
134
 
135
+ extracted_wm_text = bwm.extract(temp_image_path, wm_shape=(32, 32), mode='str')
136
+ os.remove(temp_image_path)
137
+ return extracted_wm_text
138
+
139
+ def process_image(image, model_name, attack_types, eps_value, wm_text, password_img, password_wm):
140
+ try:
141
+ adv_image = generate_adversarial_image(image, model_name, attack_types, eps_value)
142
+ except Exception as e:
143
+ error_message = f"Error in adversarial generation: {str(e)}"
144
+ return image, error_message, None, None, None
145
+
146
+ try:
147
+ watermarked_image = apply_watermark(adv_image, wm_text, int(password_img), int(password_wm))
148
+ except Exception as e:
149
+ error_message = f"Error in watermarking: {str(e)}"
150
+ return image, adv_image, error_message, None, None
151
+
152
+ try:
153
+ extracted_wm_text = extract_watermark(watermarked_image, int(password_img), int(password_wm))
154
+ except Exception as e:
155
+ error_message = f"Error in watermark extraction: {str(e)}"
156
+ return image, adv_image, watermarked_image, error_message, None
157
+
158
+ output_path = tempfile.mktemp(suffix=".png")
159
+ watermarked_image.save(output_path, format="PNG")
160
+ return image, adv_image, watermarked_image, extracted_wm_text, output_path
161
+
162
+ def download_image_as_png(image_path):
163
+ with open(image_path, "rb") as file:
164
+ return file.read(), "image/png"
165
+
166
+ interface = gr.Interface(
167
  fn=process_image,
168
+ inputs=[
169
+ gr.Image(type="pil", label="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”"),
170
+ gr.Dropdown(choices=["ResNet50", "VGG16"], label="๋ชจ๋ธ ์„ ํƒ"),
171
+ gr.CheckboxGroup(choices=["FGSM", "C&W", "DeepFool", "AutoAttack", "PGD", "BIM", "STA", "MIM", "JSMA", "NewtonFool"], label="๊ณต๊ฒฉ ์œ ํ˜• ์„ ํƒ"),
172
+ gr.Slider(0.001, 0.9, step=0.001, value=0.005, label="Epsilon ๊ฐ’ ์„ค์ • (๋…ธ์ด์ฆˆ ๊ฐ•๋„)"),
173
+ gr.Textbox(label="์›Œํ„ฐ๋งˆํฌ ํ…์ŠคํŠธ ์ž…๋ ฅ", value="ํ…์ŠคํŠธ ์‚ฝ์ž…"),
174
+ gr.Number(label="์ด๋ฏธ์ง€ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=0),
175
+ gr.Number(label="์›Œํ„ฐ๋งˆํฌ ๋น„๋ฐ€๋ฒˆํ˜ธ", value=0)
176
  ],
177
+ outputs=[
178
+ gr.Image(type="numpy", label="์›๋ณธ ์ด๋ฏธ์ง€"),
179
+ gr.Image(type="numpy", label="์ ๋Œ€์  ์ด๋ฏธ์ง€ ์ƒ์„ฑ ๋‹จ๊ณ„"),
180
+ gr.Image(type="numpy", label="์›Œํ„ฐ๋งˆํฌ๊ฐ€ ์‚ฝ์ž…๋œ ์ตœ์ข… ์ด๋ฏธ์ง€"),
181
+ gr.Textbox(label="์ถ”์ถœ๋œ ์›Œํ„ฐ๋งˆํฌ ํ…์ŠคํŠธ"),
182
+ gr.File(label="PNG๋กœ ๋‹ค์šด๋กœ๋“œ")
183
+ ]
184
+ )
185
+
186
+ interface.launch(debug=True, share=True)