spuuntries commited on
Commit
e8590af
·
1 Parent(s): 1be0680

feat: add app script

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from transformers import ViTForImageClassification, ViTConfig
7
+ import random
8
+ import numpy as np
9
+ import transformers
10
+ from skimage.metrics import structural_similarity as ssim
11
+ import requests
12
+ import os
13
+
14
+
15
+ def set_seed(seed):
16
+ random.seed(seed)
17
+ np.random.seed(seed)
18
+ torch.manual_seed(seed)
19
+ torch.cuda.manual_seed_all(seed)
20
+ transformers.set_seed(seed)
21
+ torch.backends.cudnn.deterministic = True
22
+ torch.backends.cudnn.benchmark = False
23
+
24
+
25
+ set_seed(42)
26
+
27
+ device = "cpu"
28
+ config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
29
+ config.num_labels = 2 # Binary classification
30
+
31
+ # Download the model file
32
+ model_url = "https://huggingface.co/spuun/yummy-paws/resolve/main/best_model.pth"
33
+ model_path = "best_model.pth"
34
+
35
+ if not os.path.exists(model_path):
36
+ response = requests.get(model_url)
37
+ with open(model_path, "wb") as f:
38
+ f.write(response.content)
39
+
40
+ # Load the trained model
41
+ model = ViTForImageClassification.from_pretrained(
42
+ model_path, config=config, ignore_mismatched_sizes=True
43
+ )
44
+ model.classifier = nn.Linear(model.config.hidden_size, 2)
45
+ model.to(device)
46
+
47
+ # Download the reference image
48
+ reference_image_url = (
49
+ "https://huggingface.co/spuun/yummy-paws/resolve/main/images%20(15).jpeg"
50
+ )
51
+ reference_image_path = "reference_image.jpeg"
52
+
53
+ if not os.path.exists(reference_image_path):
54
+ response = requests.get(reference_image_url)
55
+ with open(reference_image_path, "wb") as f:
56
+ f.write(response.content)
57
+
58
+ # Load the reference image for SSIM comparison
59
+ reference_image = Image.open(reference_image_path)
60
+
61
+
62
+ def calculate_ssim(img1, img2):
63
+ img1_array = np.array(img1)
64
+ img2_array = np.array(img2)
65
+ ssim_value = ssim(img1_array, img2_array, channel_axis=2)
66
+ return ssim_value
67
+
68
+
69
+ def predict_and_compare(image):
70
+ image = image.resize(reference_image.size)
71
+ ssim_value = calculate_ssim(image, reference_image)
72
+
73
+ transform = transforms.Compose(
74
+ [
75
+ transforms.Resize((224, 224)),
76
+ transforms.ToTensor(),
77
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
78
+ ]
79
+ )
80
+
81
+ image_tensor = transform(image).unsqueeze(0).to(device)
82
+
83
+ model.eval()
84
+ with torch.no_grad():
85
+ output = model(image_tensor).logits
86
+ probabilities = torch.softmax(output, dim=1)[0]
87
+ predicted_class_index = torch.argmax(probabilities).item()
88
+
89
+ class_names = ["False", "True"] # Assuming 0 index is False, 1 is True
90
+ predicted_class = class_names[predicted_class_index]
91
+ probability = probabilities[predicted_class_index].item()
92
+
93
+ return f"Predicted: {predicted_class}\nProbability: {probability:.4f}\nSSIM with reference: {ssim_value:.4f}"
94
+
95
+
96
+ iface = gr.Interface(
97
+ fn=predict_and_compare,
98
+ inputs=gr.Image(type="pil"),
99
+ outputs="text",
100
+ title="Image Classification and Comparison",
101
+ description="Upload an image to classify it and compare with a reference image.",
102
+ )
103
+
104
+ iface.launch()