Spaces:
Build error
Build error
| import PIL | |
| from PIL import Image | |
| import numpy as np | |
| from matplotlib import pylab as P | |
| import cv2 | |
| import torch | |
| from torch.utils.data import TensorDataset | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
| from torchvex.base import ExplanationMethod | |
| from torchvex.utils.normalization import clamp_quantile | |
| from backend.utils import load_image, load_model | |
| from backend.smooth_grad import generate_smoothgrad_mask | |
| import streamlit as st | |
| IMAGENET_DEFAULT_MEAN = np.asarray(IMAGENET_DEFAULT_MEAN).reshape([1,3,1,1]) | |
| IMAGENET_DEFAULT_STD = np.asarray(IMAGENET_DEFAULT_STD).reshape([1,3,1,1]) | |
| def deprocess_image(image_inputs): | |
| return (image_inputs * IMAGENET_DEFAULT_STD + IMAGENET_DEFAULT_MEAN) * 255 | |
| def feed_forward(input_image): | |
| model, feature_extractor = load_model('ConvNeXt') | |
| inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values'] | |
| logits = model(inputs).logits | |
| prediction_prob = F.softmax(logits, dim=-1).max() # prediction probability | |
| # prediction class id, start from 1 to 1000 so it needs to +1 in the end | |
| prediction_class = logits.argmax(-1).item() | |
| prediction_label = model.config.id2label[prediction_class] # prediction class label | |
| return prediction_prob, prediction_class, prediction_label | |
| # FGSM attack code | |
| def fgsm_attack(image, epsilon, data_grad): | |
| # Collect the element-wise sign of the data gradient and normalize it | |
| sign_data_grad = torch.gt(data_grad, 0).type(torch.FloatTensor) * 2.0 - 1.0 | |
| perturbed_image = image + epsilon*sign_data_grad | |
| return perturbed_image | |
| # perform attack on the model | |
| def perform_attack(input_image, target, epsilon): | |
| model, feature_extractor = load_model("ConvNeXt") | |
| # preprocess input image | |
| inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values'] | |
| inputs.requires_grad = True | |
| # predict | |
| logits = model(inputs).logits | |
| prediction_prob = F.softmax(logits, dim=-1).max() | |
| prediction_class = logits.argmax(-1).item() | |
| prediction_label = model.config.id2label[prediction_class] | |
| # Calculate the loss | |
| loss = F.nll_loss(logits, torch.tensor([target])) | |
| # Zero all existing gradients | |
| model.zero_grad() | |
| # Calculate gradients of model in backward pass | |
| loss.backward() | |
| # Collect datagrad | |
| data_grad = inputs.grad.data | |
| # Call FGSM Attack | |
| perturbed_data = fgsm_attack(inputs, epsilon, data_grad) | |
| # Re-classify the perturbed image | |
| new_prediction = model(perturbed_data).logits | |
| new_pred_prob = F.softmax(new_prediction, dim=-1).max() | |
| new_pred_class = new_prediction.argmax(-1).item() | |
| new_pred_label = model.config.id2label[new_pred_class] | |
| return perturbed_data, new_pred_prob.item(), new_pred_class, new_pred_label | |
| def find_smallest_epsilon(input_image, target): | |
| epsilons = [i*0.001 for i in range(1000)] | |
| for epsilon in epsilons: | |
| perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, target, epsilon) | |
| if new_id != target: | |
| return perturbed_data, new_prob, new_id, new_label, epsilon | |
| return None | |
| # @st.cache_data | |
| def generate_images(image_id, epsilon=0): | |
| model, feature_extractor = load_model("ConvNeXt") | |
| original_image_dict = load_image(image_id) | |
| image = original_image_dict['image'] | |
| return generate_smoothgrad_mask( | |
| image, 'ConvNeXt', | |
| model, feature_extractor, num_samples=10, return_mask=True) | |