rdkulkarni committed • Commit 5e5d09e • 1 Parent(s): 4a1f9c5

Update app.py
app.py CHANGED
@@ -6,6 +6,7 @@ from torchvision import models
 from torch import nn
 from typing import List
 import json
+import pandas as pd

 #Read labels file1
 with open('cat_to_name.json','r') as f:
@@ -69,37 +70,54 @@ def pred_image(model, image_path, class_names = None, transform=None, device: to
 def process_input(image_path):

     #Load Model
-
+    list_of_models_and_weights = [
+        ('efficientnet_b2','EfficientNet_B2_Weights','flowers_efficientnet_b2_model.pth'),
+        ('alexnet','AlexNet_Weights','flowers_alexnet_model.pth')
+        #('mobilenet_v2','MobileNet_V2_Weights','flowers_mobilenet_v2_model.pth'),
+        #('densenet121','DenseNet121_Weights','flowers_densenet121_model.pth'),
+        #('inception_v3','Inception_V3_Weights','flowers_inception_v3_model.pth'),
+        #('squeezenet1_1','SqueezeNet1_1_Weights','flowers_squeezenet1_1_model.pth'),
+        #('vgg16','VGG16_Weights','flowers_vgg16_model.pth'),
+        #('resnet18','ResNet18_Weights','flowers_resnet18_model.pth'),
+        #('swin_b','Swin_B_Weights','flowers_swin_b_model.pth'),
+        #('vit_b_16', 'ViT_B_16_Weights','flowers_vit_b_16_model.pth')
+    ]
+    #model_name, model_weights, model_path = ('efficientnet_b2','EfficientNet_B2_Weights','flowers_efficientnet_b2_model.pth')
     #model_name, model_weights, model_path = ('alexnet','AlexNet_Weights','flowers_alexnet_model.pth')
-
-
-
-
-
-
-
-
-
-
+    list_of_outputs = []
+    for model_name, model_weights, model_path in list_of_models_and_weights:
+        checkpoint = torch.load(model_path, map_location='cpu')
+        pretrained_weights = eval(f"models.{model_weights}.DEFAULT")
+        auto_transforms = pretrained_weights.transforms()
+        #pretrained_model = eval(f"torchvision.models.{model_name}(weights = pretrained_weights)")
+        pretrained_model = eval(f"models.{model_name}(pretrained = True)")
+        pretrained_model = update_last_layer_pretrained_model(pretrained_model, 102, True)
+        pretrained_model.class_to_idx = checkpoint['class_to_idx']
+        pretrained_model.class_names = checkpoint['class_names']
+        pretrained_model.load_state_dict(checkpoint['state_dict'])
+        pretrained_model.to('cpu')
+
+        #Predict
+        #image_path = 'which-flower/80_image_02020.jpg'
+        probs, idxs = pred_image(model=pretrained_model, image_path=image_path, class_names=pretrained_model.class_names, transform=auto_transforms)
+        names = [cat_to_name[i] for i in idxs]

-
-
-    probs, idxs = pred_image(model=pretrained_model, image_path=image_path, class_names=pretrained_model.class_names, transform=auto_transforms)
-    names = [cat_to_name[i] for i in idxs]
-
+        list_of_outputs.append({"Prediction" : names[0], "Probability" : probs[0]})
+
     #Display or return to main function
-    print({names[i]: float(probs[i]) for i in range(len(names))})
-    return {names[i]: float(probs[i]) for i in range(len(names))}, {names[i]: float(probs[i]) for i in range(len(names))}
+    #print({names[i]: float(probs[i]) for i in range(len(names))})
     #return {names[i]: float(probs[i]) for i in range(len(names))}
-
+    #oldreturn {names[i]: float(probs[i]) for i in range(len(names))}
+    print(pd.DataFrame(list_of_outputs))
+    return pd.DataFrame(list_of_outputs)
+
 examples = ['16_image_06670.jpg','33_image_06460.jpg','80_image_02020.jpg', 'Flowers.png','inference_example.png']
 title = "Image Classifier - Species of Flower predicted by different Models"
 description = "Image classifiers to recognize different species of flowers trained on 102 Category Flower Dataset"
 article = article="<p style='text-align: center'><a href='https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html' target='_blank'>Source 102 Flower Dataset</a></p>"
 interpretation = 'default'
 enable_queue = True
-iface = gr.Interface(fn=process_input, inputs=gr.inputs.Image(type='filepath'), outputs=
-)
+iface = gr.Interface(fn=process_input, inputs=gr.inputs.Image(type='filepath'), outputs="dataframe", examples = examples, title=title, description=description,article=article,interpretation=interpretation, enable_queue=enable_queue)
 iface.launch()

 #(num_top_classes=3)q
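
The new loop builds each backbone with eval() on a formatted string and the deprecated pretrained=True flag. A minimal sketch of the same loading step using getattr() and the weights= argument instead; it assumes the update_last_layer_pretrained_model helper defined earlier in app.py and the checkpoint layout shown in the diff:

import torch
from torchvision import models

def load_finetuned_model(model_name, weights_name, checkpoint_path):
    # Look the weight enum and model builder up by attribute instead of eval().
    weights_enum = getattr(models, weights_name).DEFAULT        # e.g. models.AlexNet_Weights.DEFAULT
    auto_transforms = weights_enum.transforms()                 # preprocessing the weights were trained with
    model = getattr(models, model_name)(weights=weights_enum)   # e.g. models.alexnet(weights=...)
    # Re-head for the 102 flower classes, then restore the fine-tuned state.
    # update_last_layer_pretrained_model is defined elsewhere in app.py.
    model = update_last_layer_pretrained_model(model, 102, True)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.class_to_idx = checkpoint['class_to_idx']
    model.class_names = checkpoint['class_names']
    return model.to('cpu'), auto_transforms

process_input could then call load_finetuned_model(model_name, model_weights, model_path) inside its for loop and keep the rest of the prediction code unchanged.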
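Each flowers_*_model.pth file is read back with three keys (state_dict, class_to_idx, class_names). A hypothetical sketch of how such a checkpoint could be written at training time so that it matches what process_input expects:

import torch

def save_flower_checkpoint(model, class_to_idx, class_names, path):
    # The keys mirror what process_input reads back out of torch.load().
    torch.save({
        'state_dict': model.state_dict(),  # weights of the re-headed, fine-tuned model
        'class_to_idx': class_to_idx,      # category id -> output index (e.g. from the training ImageFolder)
        'class_names': class_names,        # category ids in output-index order, used to index cat_to_name
    }, path)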
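The interface still uses the Gradio 2.x-style gr.inputs.Image namespace; in Gradio 3 and later the components live at the top level, and recent releases dropped the interpretation= and enable_queue= keyword arguments from gr.Interface. A hedged sketch of the equivalent wiring on the newer API:

import gradio as gr

iface = gr.Interface(
    fn=process_input,                  # returns a pandas DataFrame, one row per model
    inputs=gr.Image(type='filepath'),  # hand the uploaded image to process_input as a file path
    outputs=gr.Dataframe(),            # render the Prediction/Probability table
    examples=examples,
    title=title,
    description=description,
    article=article,
)
iface.launch()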