import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTConfig
import random
import numpy as np
import transformers
from skimage.metrics import structural_similarity as ssim
import requests
import os


def set_seed(seed):
    """Make runs reproducible across Python, NumPy, PyTorch, and transformers."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    transformers.set_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

device = "cpu"

config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
config.num_labels = 2  # Binary classification

# Download the model file if it isn't cached locally
model_url = "https://huggingface.co/spuun/yummy-paws/resolve/main/model.pth"
model_path = "best_model.pth"

if not os.path.exists(model_path):
    response = requests.get(model_url)
    with open(model_path, "wb") as f:
        f.write(response.content)

# Load the trained model; ignore_mismatched_sizes lets the 2-class head be used
# even if the checkpoint's classifier weights have a different shape
model = ViTForImageClassification.from_pretrained(
    model_path,
    config=config,
    ignore_mismatched_sizes=True,
)
model.to(device)

# Download the reference image if it isn't cached locally
reference_image_url = (
    "https://huggingface.co/spuun/yummy-paws/resolve/main/images%20(15).jpeg"
)
reference_image_path = "reference_image.jpeg"

if not os.path.exists(reference_image_path):
    response = requests.get(reference_image_url)
    with open(reference_image_path, "wb") as f:
        f.write(response.content)

# Load the reference image for SSIM comparison (forced to RGB so array shapes match)
reference_image = Image.open(reference_image_path).convert("RGB")


def calculate_ssim(img1, img2):
    """Structural similarity between two same-sized RGB PIL images."""
    img1_array = np.array(img1)
    img2_array = np.array(img2)
    return ssim(img1_array, img2_array, channel_axis=2)


def predict_and_compare(image):
    # Force 3 channels: grayscale or RGBA uploads would break SSIM and Normalize
    image = image.convert("RGB")

    # SSIM needs identical dimensions, so resize to the reference image first
    image = image.resize(reference_image.size)
    ssim_value = calculate_ssim(image, reference_image)

    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_tensor = transform(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        output = model(image_tensor).logits
    probabilities = torch.softmax(output, dim=1)[0]
    predicted_class_index = torch.argmax(probabilities).item()

    class_names = ["False", "True"]  # Assuming index 0 is False, 1 is True
    predicted_class = class_names[predicted_class_index]
    probability = probabilities[predicted_class_index].item()

    return (
        f"Predicted: {predicted_class}\n"
        f"Probability: {probability:.4f}\n"
        f"SSIM with reference: {ssim_value:.4f}"
    )


iface = gr.Interface(
    fn=predict_and_compare,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image Classification and Comparison",
    description="Upload an image to classify it and compare with a reference image.",
)

iface.launch()
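
# A minimal local sanity check (a sketch, not part of the deployed app): call
# predict_and_compare directly instead of going through the Gradio UI.
# "test.jpg" is a hypothetical local file; uncomment to try it before launch().
#
# if __name__ == "__main__":
#     sample = Image.open("test.jpg")
#     print(predict_and_compare(sample))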