File size: 3,648 Bytes
fcc16aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import PIL
from PIL import Image
import numpy as np
from matplotlib import pylab as P
import cv2

import torch
from torch.utils.data import TensorDataset
from torchvision import transforms
import torch.nn.functional as F

from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from torchvex.base import ExplanationMethod
from torchvex.utils.normalization import clamp_quantile

from backend.utils import load_image, load_model
from backend.smooth_grad import generate_smoothgrad_mask

import streamlit as st

IMAGENET_DEFAULT_MEAN = np.asarray(IMAGENET_DEFAULT_MEAN).reshape([1,3,1,1])
IMAGENET_DEFAULT_STD = np.asarray(IMAGENET_DEFAULT_STD).reshape([1,3,1,1])

def deprocess_image(image_inputs):
    """Undo the ImageNet normalization and rescale to the 0-255 pixel range.

    Inverse of `(x/255 - mean) / std`: multiply by the per-channel std,
    add back the per-channel mean, then scale to byte range.
    """
    unnormalized = image_inputs * IMAGENET_DEFAULT_STD + IMAGENET_DEFAULT_MEAN
    return unnormalized * 255


def feed_forward(input_image):
    """Classify a single image with ConvNeXt.

    Returns a tuple ``(top_prob, top_class, top_label)``: the softmax
    probability of the predicted class (0-dim tensor), the integer class id,
    and its human-readable label from the model config.
    """
    model, feature_extractor = load_model('ConvNeXt')
    pixel_values = feature_extractor(
        input_image, do_resize=False, return_tensors="pt")['pixel_values']
    logits = model(pixel_values).logits
    # Softmax is monotonic, so the max probability belongs to argmax(logits).
    top_prob = F.softmax(logits, dim=-1).max()
    top_class = logits.argmax(-1).item()
    # NOTE(review): an earlier comment claimed class ids start at 1 and need a
    # +1 offset, yet id2label is indexed directly here — confirm which is right.
    top_label = model.config.id2label[top_class]
    return top_prob, top_class, top_label

# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient and normalize it
    sign_data_grad = torch.gt(data_grad, 0).type(torch.FloatTensor) * 2.0 - 1.0
    perturbed_image = image + epsilon*sign_data_grad
    return perturbed_image

# perform attack on the model
def perform_attack(input_image, target, epsilon):
    model, feature_extractor = load_model("ConvNeXt")
    # preprocess input image
    inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
    inputs.requires_grad = True
    
    # predict
    logits = model(inputs).logits
    prediction_prob = F.softmax(logits, dim=-1).max()
    prediction_class = logits.argmax(-1).item()
    prediction_label = model.config.id2label[prediction_class]
    
    # Calculate the loss
    loss = F.nll_loss(logits, torch.tensor([target]))

    # Zero all existing gradients
    model.zero_grad()

    # Calculate gradients of model in backward pass
    loss.backward()
    
    # Collect datagrad
    data_grad = inputs.grad.data
    
    # Call FGSM Attack
    perturbed_data = fgsm_attack(inputs, epsilon, data_grad)
    
    # Re-classify the perturbed image
    new_prediction = model(perturbed_data).logits
    new_pred_prob = F.softmax(new_prediction, dim=-1).max()
    new_pred_class = new_prediction.argmax(-1).item()
    new_pred_label = model.config.id2label[new_pred_class]

    return perturbed_data, new_pred_prob.item(), new_pred_class, new_pred_label


def find_smallest_epsilon(input_image, target):
    """Scan epsilons 0.000, 0.001, ..., 0.999 in order and return the first
    attack whose prediction differs from ``target``.

    Returns ``(perturbed_data, new_prob, new_id, new_label, epsilon)`` on
    success, or ``None`` if no epsilon in the range flips the prediction.
    """
    for step in range(1000):
        epsilon = step * 0.001
        perturbed_data, new_prob, new_id, new_label = perform_attack(
            input_image, target, epsilon)
        if new_id != target:
            return perturbed_data, new_prob, new_id, new_label, epsilon
    return None

# @st.cache_data
@st.cache(allow_output_mutation=True)
def generate_images(image_id, epsilon=0):
    """Compute a SmoothGrad saliency mask for the image with ``image_id``.

    ``epsilon`` is not used by the computation here but is part of the
    signature (and therefore the Streamlit cache key) — presumably reserved
    for perturbed variants; confirm with callers.
    """
    model, feature_extractor = load_model("ConvNeXt")
    image = load_image(image_id)['image']
    return generate_smoothgrad_mask(
        image, 'ConvNeXt',
        model, feature_extractor, num_samples=10, return_mask=True)