AdversarialArt / app.py
will33am's picture
update
3ea0526
import torch
import torch.nn as nn
from robustness.datasets import ImageNet
from robustness.attacker import AttackerModel
from timm.models import create_model
from torchvision import transforms
from robustness.tools.label_maps import CLASS_DICT
from src.utils import *
from torchvision import transforms
import gradio as gr
import os
from PIL import Image
# Maps each target label offered in the UI to the integer class index that is
# handed to the targeted attack.
# NOTE(review): these look like ImageNet-1k indices; in the standard mapping
# 955 appears to be "jackfruit" while "lakeside" is 975 -- confirm the 'lake'
# entry against the ImageNet label map before relying on it.
DICT_CLASSES = {
    'lake': 955,
    'castle': 483,
    'library': 624,
    'dog': 235,
    'cat': 285,
    'people': 842,  # trunks
}

# Side length (in pixels) every input image is resized to before the attack.
IMG_MAX_SIZE = 256

# Architecture name passed to timm's ``create_model`` and the checkpoint that
# is loaded on top of it when the robust variant is requested.
ARCH = 'crossvit_18_dagger_408'
ARCH_PATH = './checkpoints/robust_crossvit_18_dagger_408.pt'

# Preprocessing shared by the sample image and user uploads: resize to a fixed
# square, then convert to a tensor.
CUSTOM_TRANSFORMS = transforms.Compose([
    transforms.Resize([IMG_MAX_SIZE, IMG_MAX_SIZE]),
    transforms.ToTensor(),
])

# Use the GPU when one is present, otherwise fall back to the CPU so the app
# still starts on hosts without CUDA instead of failing on the first ``.to``.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Models that have already been built, keyed by the truthiness of ``robust``,
# so repeated requests reuse them instead of rebuilding from disk every time.
_MODEL_CACHE = {}


def _build_model(robust):
    """Construct one attack-ready model variant (uncached)."""
    # The sample image only seeds the CustomArt dataset that AttackerModel is
    # constructed with; the context manager closes the file afterwards.
    with Image.open('samples/test.png') as test_image:
        ds = CustomArt(test_image, CUSTOM_TRANSFORMS)
        model = create_model(ARCH, pretrained=True).to(DEVICE)
        if robust:
            print("Load Robust Model")
            # NOTE(security): torch.load unpickles the file, so ARCH_PATH must
            # only ever point at a trusted checkpoint.
            checkpoint = torch.load(ARCH_PATH, map_location=DEVICE)
            model.load_state_dict(checkpoint['state_dict'], strict=True)
        # Both variants are wrapped the same way; only the weights differ.
        model = RobustModel(model).to(DEVICE)
        model = AttackerModel(model, ds).to(DEVICE)
    return model.eval()


def load_model(robust = True):
    """Return the model used to generate adversarial images, in eval mode.

    The model is built on first use and cached per variant, so subsequent
    calls with the same ``robust`` truthiness return the same instance.

    Args:
        robust: When truthy, the ``'state_dict'`` entry of the checkpoint at
            ``ARCH_PATH`` is loaded (strictly) into the ``ARCH`` model created
            by ``create_model``; otherwise the pretrained weights are kept.

    Returns:
        An ``AttackerModel`` wrapping ``RobustModel`` wrapping the ``ARCH``
        model, moved to ``DEVICE`` and switched to eval mode.
    """
    key = bool(robust)
    try:
        return _MODEL_CACHE[key]
    except KeyError:
        # Only store the model once it has been built successfully, so a
        # failed build is retried on the next call.
        model = _build_model(key)
        _MODEL_CACHE[key] = model
        return model
def gradio_fn(image_input,radio_steps,radio_class,radio_robust):
    """Run a targeted L2 attack on the uploaded image and return the result.

    Args:
        image_input: The uploaded image as an array accepted by
            ``PIL.Image.fromarray``; ``None`` when nothing was uploaded.
        radio_steps: Number of attack iterations (converted with ``int``).
        radio_class: Key of ``DICT_CLASSES`` naming the target class.
        radio_robust: Forwarded to ``load_model`` to pick the model variant.

    Returns:
        The adversarial image as a NumPy array with the batch dimension
        removed and the channel axis moved last.

    Raises:
        ValueError: If ``image_input`` is ``None``.
        KeyError: If ``radio_class`` is not a key of ``DICT_CLASSES``.
    """
    # Fail early with a readable message instead of an opaque error from
    # Image.fromarray, and before paying the cost of loading a model.
    if image_input is None:
        raise ValueError("An input image is required to compute the attack.")

    model = load_model(radio_robust)
    attack_kwargs = {
        'constraint': '2',  # L2 attack
        'eps': 300,
        'step_size': 1,
        'iterations': int(radio_steps),
        'targeted': True,
        'do_tqdm': True,
        'device': DEVICE,
    }

    # Define the target and the image
    target = torch.tensor([int(DICT_CLASSES[radio_class])]).to(DEVICE)
    image = CUSTOM_TRANSFORMS(Image.fromarray(image_input)).to(DEVICE)
    # The model is called on a batch, so add a leading batch dimension of 1.
    image = torch.unsqueeze(image, dim=0)

    _, im_adv = model(image, target, make_adv=True, **attack_kwargs)

    # Detach so the conversion to NumPy cannot fail on a tensor that still
    # tracks gradients, then drop the batch axis and move channels last.
    im_adv = im_adv.detach().squeeze(dim=0).permute(1, 2, 0).cpu().numpy()
    return im_adv
def _build_demo():
    """Assemble the Gradio Blocks UI and wire the Compute button to gradio_fn."""
    class_names = list(DICT_CLASSES.keys())

    with gr.Blocks() as ui:
        gr.Markdown("# Art Adversarial Attack")
        with gr.Row():
            # Left column: attack settings, the input image and the trigger.
            with gr.Column():
                with gr.Row():
                    # Number of attack steps
                    radio_steps = gr.Radio(
                        [10, 500, 1000, 1500, 2000],
                        value=500,
                        label="# Attack Steps",
                    )
                    # Target class of the attack (first entry preselected)
                    radio_class = gr.Radio(
                        class_names,
                        value=class_names[0],
                        label="Target Class",
                    )
                    radio_robust = gr.Radio(
                        [True, False],
                        value=True,
                        label="Robust Model",
                    )
                # Image
                with gr.Row():
                    image_input = gr.Image(label="Input Image")
                with gr.Row():
                    calculate_button = gr.Button("Compute")
            # Right column: the generated image.
            with gr.Column():
                target_image = gr.Image(label="Art Image")

        calculate_button.click(
            fn=gradio_fn,
            inputs=[image_input, radio_steps, radio_class, radio_robust],
            outputs=target_image,
        )
    return ui


if __name__ == '__main__':
    demo = _build_demo()
    demo.launch(debug = True)