import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTConfig
import random
import numpy as np
import transformers
from skimage.metrics import structural_similarity as ssim
import requests
import os

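# Seed every RNG in play (Python, NumPy, PyTorch CPU/CUDA, transformers) and
# force deterministic cuDNN, so the freshly initialised classifier head below
# gets the same weights on every run.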
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    transformers.set_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

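# The real flag comes from the environment; a placeholder is used locally.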
flag = os.environ.get("FLAG", "fakeflag{placeholder}")
device = "cpu"
config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
config.num_labels = 2

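# Download the fine-tuned ViT checkpoint once and cache it next to the app.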
model_url = "https://huggingface.co/spuun/yummy-paws/resolve/main/model.pth"
model_path = "best_model.pth"

if not os.path.exists(model_path):
    response = requests.get(model_url)
    with open(model_path, "wb") as f:
        f.write(response.content)

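# Load the checkpoint into a 2-label ViT, then replace the classifier with a
# freshly initialised linear head (deterministic thanks to the fixed seed).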
model = ViTForImageClassification.from_pretrained(
    model_path,
    config=config,
    ignore_mismatched_sizes=True,
)
model.classifier = nn.Linear(model.config.hidden_size, 2)
model.to(device)

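# Fetch the reference photo that every submission is compared against.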
reference_image_url = (
    "https://huggingface.co/spuun/yummy-paws/resolve/main/images%20(15).jpeg"
)
reference_image_path = "reference_image.jpeg"

if not os.path.exists(reference_image_path):
    response = requests.get(reference_image_url)
    with open(reference_image_path, "wb") as f:
        f.write(response.content)

reference_image = Image.open(reference_image_path)

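# Structural similarity between two same-sized colour images; channel_axis=2
# assumes H x W x C arrays, i.e. both inputs are RGB.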
def calculate_ssim(img1, img2):
    img1_array = np.array(img1)
    img2_array = np.array(img2)
    ssim_value = ssim(img1_array, img2_array, channel_axis=2)
    return ssim_value

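# Resize the upload to the reference resolution, measure its SSIM against the
# reference, then classify it; authentication requires SSIM > 0.9 and a
# "True" prediction.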
def predict_and_compare(image):
    image = image.resize(reference_image.size)
    ssim_value = calculate_ssim(image, reference_image)

    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    image_tensor = transform(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        output = model(image_tensor).logits
        probabilities = torch.softmax(output, dim=1)[0]
        predicted_class_index = torch.argmax(probabilities).item()

    class_names = ["False", "True"]
    predicted_class = class_names[predicted_class_index]
    probability = probabilities[predicted_class_index].item()

    # Build the verdict line first: a backslash inside an f-string expression
    # is a SyntaxError on Python < 3.12.
    authenticated = ssim_value > 0.9 and predicted_class == "True"
    verdict = (
        "SUCCESSFULLY AUTHENTICATED!!\nFLAG: " + flag
        if authenticated
        else "FAILED TO AUTHENTICATE :("
    )
    return f"""{verdict}
=====================

Predicted: {predicted_class}
Probability: {probability:.4f}
SSIM with reference: {ssim_value:.4f}"""

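# Single-input Gradio app: an uploaded image in, a text verdict (and, on
# success, the flag) out.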
iface = gr.Interface(
    fn=predict_and_compare,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image authentication",
    description="Upload your image here to be authenticated!",
    allow_flagging="never",
)

iface.launch()