# yumyum / app.py
# spuuntries
# fix: quotes
# 47654b5
# raw
# history blame
# 3.16 kB
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTConfig
import random
import numpy as np
import transformers
from skimage.metrics import structural_similarity as ssim
import requests
import os
def set_seed(seed):
    """Seed every random number generator this script touches.

    Applies ``seed`` to Python's ``random``, NumPy, PyTorch (CPU and all CUDA
    devices) and ``transformers``, then sets cuDNN's ``deterministic`` flag
    and clears its ``benchmark`` flag.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def _download_if_missing(url, path, timeout=60):
    """Fetch ``url`` into ``path`` unless ``path`` already exists.

    The body is streamed to a sibling ``<path>.part`` file and only renamed to
    ``path`` once the whole response has been written, so an interrupted or
    failed download can never be mistaken for a valid cached file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    if os.path.exists(path):
        return
    partial_path = f"{path}.part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an HTML error page would be saved as the asset.
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(partial_path, path)


# Seed the RNGs before anything below creates randomly initialised weights.
set_seed(42)

# The flag comes from the FLAG environment variable; a placeholder is used
# when it is not set.
flag = os.environ.get("FLAG", "fakeflag{placeholder}")
device = "cpu"

config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
config.num_labels = 2  # Binary classification

model_url = "https://huggingface.co/spuun/yummy-paws/resolve/main/model.pth"
model_path = "best_model.pth"
_download_if_missing(model_url, model_path)

model = ViTForImageClassification.from_pretrained(
    model_path,
    config=config,
    ignore_mismatched_sizes=True,
)
# NOTE(review): the classification head is replaced with a new nn.Linear
# *after* the checkpoint has been loaded, so any head weights stored in the
# checkpoint are not the ones used at inference. Left as-is because changing
# it would change the model's outputs -- confirm this is intended.
model.classifier = nn.Linear(model.config.hidden_size, 2)
model.to(device)

reference_image_url = (
    "https://huggingface.co/spuun/yummy-paws/resolve/main/images%20(15).jpeg"
)
reference_image_path = "reference_image.jpeg"
_download_if_missing(reference_image_url, reference_image_path)

# Decode eagerly and close the file: Image.open() alone is lazy and would keep
# the handle open for the lifetime of the process. Converting to RGB gives the
# three-channel layout that the SSIM comparison (channel_axis=2) expects.
with Image.open(reference_image_path) as opened_reference:
    reference_image = opened_reference.convert("RGB")
def calculate_ssim(img1, img2):
    """Return the structural similarity (SSIM) score of two images.

    Both arguments are converted to NumPy arrays and compared with
    ``structural_similarity``, with axis 2 treated as the channel axis.
    """
    pixels_a, pixels_b = (np.array(img) for img in (img1, img2))
    return ssim(pixels_a, pixels_b, channel_axis=2)
# Built once at import time instead of on every request; the pipeline holds
# no per-call state.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def predict_and_compare(image):
    """Classify an uploaded image and compare it with the reference image.

    The image is resized to the reference image's size, scored against it
    with SSIM, then preprocessed and run through the classifier. The flag is
    revealed only when the SSIM score is above 0.9 *and* the predicted class
    is ``"True"``.

    Args:
        image: the uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A multi-line report: the authentication status (with the flag on
        success), the predicted class, its probability and the SSIM score.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio passes None when the form is submitted without an upload.
        raise gr.Error("Please upload an image to authenticate.")

    # Normalize uses three-channel statistics and the SSIM helper compares
    # along channel axis 2, so make sure the input really has three channels.
    image = image.convert("RGB").resize(reference_image.size)
    ssim_value = calculate_ssim(image, reference_image)

    image_tensor = _PREPROCESS(image).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(image_tensor).logits
    probabilities = torch.softmax(output, dim=1)[0]
    predicted_class_index = torch.argmax(probabilities).item()
    class_names = ("False", "True")  # Assuming 0 index is False, 1 is True
    predicted_class = class_names[predicted_class_index]
    probability = probabilities[predicted_class_index].item()

    # The status line is built outside the f-string: a backslash inside an
    # f-string replacement field is a SyntaxError before Python 3.12.
    if ssim_value > 0.9 and predicted_class == "True":
        status = "SUCCESSFULLY AUTHENTICATED!!\nFLAG: " + flag
    else:
        status = "FAILED TO AUTHENTICATE :("

    return "\n".join(
        [
            status,
            "=====================",
            f"Predicted: {predicted_class}",
            f"Probability: {probability:.4f}",
            f"SSIM with reference: {ssim_value:.4f}",
        ]
    )
# Web UI: one image upload (delivered to the handler as a PIL image) in,
# plain-text report out. Flagging is turned off.
_image_input = gr.Image(type="pil")
iface = gr.Interface(
    predict_and_compare,
    _image_input,
    "text",
    allow_flagging="never",
    title="Image authentication",
    description="Upload your image here to be authenticated!",
)

# Start serving as soon as the script has finished setting up.
iface.launch()