taquynhnga's picture
Update backend/adversarial_attack.py
932fc3c
raw
history blame
3.65 kB
import PIL
from PIL import Image
import numpy as np
from matplotlib import pylab as P
import cv2
import torch
from torch.utils.data import TensorDataset
from torchvision import transforms
import torch.nn.functional as F
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvex.base import ExplanationMethod
from torchvex.utils.normalization import clamp_quantile
from backend.utils import load_image, load_model
from backend.smooth_grad import generate_smoothgrad_mask
import streamlit as st
IMAGENET_DEFAULT_MEAN = np.asarray(IMAGENET_DEFAULT_MEAN).reshape([1,3,1,1])
IMAGENET_DEFAULT_STD = np.asarray(IMAGENET_DEFAULT_STD).reshape([1,3,1,1])
def deprocess_image(image_inputs):
return (image_inputs * IMAGENET_DEFAULT_STD + IMAGENET_DEFAULT_MEAN) * 255
def feed_forward(input_image):
model, feature_extractor = load_model('ConvNeXt')
inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
logits = model(inputs).logits
prediction_prob = F.softmax(logits, dim=-1).max() # prediction probability
# prediction class id, start from 1 to 1000 so it needs to +1 in the end
prediction_class = logits.argmax(-1).item()
prediction_label = model.config.id2label[prediction_class] # prediction class label
return prediction_prob, prediction_class, prediction_label
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
# Collect the element-wise sign of the data gradient and normalize it
sign_data_grad = torch.gt(data_grad, 0).type(torch.FloatTensor) * 2.0 - 1.0
perturbed_image = image + epsilon*sign_data_grad
return perturbed_image
# perform attack on the model
def perform_attack(input_image, target, epsilon):
model, feature_extractor = load_model("ConvNeXt")
# preprocess input image
inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
inputs.requires_grad = True
# predict
logits = model(inputs).logits
prediction_prob = F.softmax(logits, dim=-1).max()
prediction_class = logits.argmax(-1).item()
prediction_label = model.config.id2label[prediction_class]
# Calculate the loss
loss = F.nll_loss(logits, torch.tensor([target]))
# Zero all existing gradients
model.zero_grad()
# Calculate gradients of model in backward pass
loss.backward()
# Collect datagrad
data_grad = inputs.grad.data
# Call FGSM Attack
perturbed_data = fgsm_attack(inputs, epsilon, data_grad)
# Re-classify the perturbed image
new_prediction = model(perturbed_data).logits
new_pred_prob = F.softmax(new_prediction, dim=-1).max()
new_pred_class = new_prediction.argmax(-1).item()
new_pred_label = model.config.id2label[new_pred_class]
return perturbed_data, new_pred_prob.item(), new_pred_class, new_pred_label
def find_smallest_epsilon(input_image, target):
epsilons = [i*0.001 for i in range(1000)]
for epsilon in epsilons:
perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, target, epsilon)
if new_id != target:
return perturbed_data, new_prob, new_id, new_label, epsilon
return None
# @st.cache_data
@st.cache(allow_output_mutation=True)
def generate_images(image_id, epsilon=0):
model, feature_extractor = load_model("ConvNeXt")
original_image_dict = load_image(image_id)
image = original_image_dict['image']
return generate_smoothgrad_mask(
image, 'ConvNeXt',
model, feature_extractor, num_samples=10, return_mask=True)