# yumyum / app.py
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTConfig
import random
import numpy as np
import transformers
from skimage.metrics import structural_similarity as ssim
import requests
import os
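# Seed every RNG in play (Python, NumPy, PyTorch CPU/CUDA, Transformers) and pin cuDNN
# to deterministic kernels so runs are reproducible.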
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    transformers.set_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)
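# Inference runs on CPU.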
device = "cpu"
config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
config.num_labels = 2 # Binary classification
# Download the model file
model_url = "https://huggingface.co/spuun/yummy-paws/resolve/main/model.pth"
model_path = "best_model.pth"
if not os.path.exists(model_path):
    response = requests.get(model_url)
    with open(model_path, "wb") as f:
        f.write(response.content)
# Load the trained model
model = ViTForImageClassification.from_pretrained(
    model_path,
    config=config,
    ignore_mismatched_sizes=True,  # weights_only=False
)
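# Swap in a fresh 2-class head; the fixed seed above keeps its random initialization reproducible.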
model.classifier = nn.Linear(model.config.hidden_size, 2)
model.to(device)
# Download the reference image
reference_image_url = (
    "https://huggingface.co/spuun/yummy-paws/resolve/main/images%20(15).jpeg"
)
reference_image_path = "reference_image.jpeg"
if not os.path.exists(reference_image_path):
    response = requests.get(reference_image_url)
    with open(reference_image_path, "wb") as f:
        f.write(response.content)
# Load the reference image for SSIM comparison
reference_image = Image.open(reference_image_path)
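# Structural similarity between two equally sized images; channel_axis=2 tells skimage
# the arrays are laid out H x W x C.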
def calculate_ssim(img1, img2):
    img1_array = np.array(img1)
    img2_array = np.array(img2)
    ssim_value = ssim(img1_array, img2_array, channel_axis=2)
    return ssim_value
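# Resize the upload to the reference's dimensions, score its similarity to the reference,
# then run it through the ViT classifier and report prediction, confidence, and SSIM.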
def predict_and_compare(image):
    image = image.resize(reference_image.size)
    ssim_value = calculate_ssim(image, reference_image)
    # Standard ViT-Base/16 preprocessing: 224x224 input, ImageNet mean/std normalization.
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_tensor = transform(image).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(image_tensor).logits
    probabilities = torch.softmax(output, dim=1)[0]
    predicted_class_index = torch.argmax(probabilities).item()
    class_names = ["False", "True"]  # Assuming 0 index is False, 1 is True
    predicted_class = class_names[predicted_class_index]
    probability = probabilities[predicted_class_index].item()
    return f"Predicted: {predicted_class}\nProbability: {probability:.4f}\nSSIM with reference: {ssim_value:.4f}"
iface = gr.Interface(
    fn=predict_and_compare,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image Classification and Comparison",
    description="Upload an image to classify it and compare with a reference image.",
)
iface.launch()