rdkulkarni commited on
Commit
5e5d09e
1 Parent(s): 4a1f9c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -21
app.py CHANGED
@@ -6,6 +6,7 @@ from torchvision import models
6
  from torch import nn
7
  from typing import List
8
  import json
 
9
 
10
  #Read labels file1
11
  with open('cat_to_name.json','r') as f:
@@ -69,37 +70,54 @@ def pred_image(model, image_path, class_names = None, transform=None, device: to
69
  def process_input(image_path):
70
 
71
  #Load Model
72
- model_name, model_weights, model_path = ('efficientnet_b2','EfficientNet_B2_Weights','flowers_efficientnet_b2_model.pth')
 
 
 
 
 
 
 
 
 
 
 
 
73
  #model_name, model_weights, model_path = ('alexnet','AlexNet_Weights','flowers_alexnet_model.pth')
74
- checkpoint = torch.load(model_path, map_location='cpu')
75
- pretrained_weights = eval(f"models.{model_weights}.DEFAULT")
76
- auto_transforms = pretrained_weights.transforms()
77
- #pretrained_model = eval(f"torchvision.models.{model_name}(weights = pretrained_weights)")
78
- pretrained_model = eval(f"models.{model_name}(pretrained = True)")
79
- pretrained_model = update_last_layer_pretrained_model(pretrained_model, 102, True)
80
- pretrained_model.class_to_idx = checkpoint['class_to_idx']
81
- pretrained_model.class_names = checkpoint['class_names']
82
- pretrained_model.load_state_dict(checkpoint['state_dict'])
83
- pretrained_model.to('cpu')
 
 
 
 
 
 
 
84
 
85
- #Predict
86
- #image_path = 'which-flower/80_image_02020.jpg'
87
- probs, idxs = pred_image(model=pretrained_model, image_path=image_path, class_names=pretrained_model.class_names, transform=auto_transforms)
88
- names = [cat_to_name[i] for i in idxs]
89
-
90
  #Display or return to main function
91
- print({names[i]: float(probs[i]) for i in range(len(names))})
92
- return {names[i]: float(probs[i]) for i in range(len(names))}, {names[i]: float(probs[i]) for i in range(len(names))}
93
  #return {names[i]: float(probs[i]) for i in range(len(names))}
94
-
 
 
 
95
  examples = ['16_image_06670.jpg','33_image_06460.jpg','80_image_02020.jpg', 'Flowers.png','inference_example.png']
96
  title = "Image Classifier - Species of Flower predicted by different Models"
97
  description = "Image classifiers to recognize different species of flowers trained on 102 Category Flower Dataset"
98
  article = article="<p style='text-align: center'><a href='https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html' target='_blank'>Source 102 Flower Dataset</a></p>"
99
  interpretation = 'default'
100
  enable_queue = True
101
- iface = gr.Interface(fn=process_input, inputs=gr.inputs.Image(type='filepath'), outputs=[gr.outputs.Label(num_top_classes=3), gr.outputs.Label(num_top_classes=3)], examples = examples, title=title, description=description,article=article,interpretation=interpretation, enable_queue=enable_queue
102
- )
103
  iface.launch()
104
 
105
  #(num_top_classes=3)q
 
6
  from torch import nn
7
  from typing import List
8
  import json
9
+ import pandas as pd
10
 
11
  #Read labels file1
12
  with open('cat_to_name.json','r') as f:
 
70
  def process_input(image_path):
71
 
72
  #Load Model
73
+ list_of_models_and_weights = [
74
+ ('efficientnet_b2','EfficientNet_B2_Weights','flowers_efficientnet_b2_model.pth'),
75
+ ('alexnet','AlexNet_Weights','flowers_alexnet_model.pth')
76
+ #('mobilenet_v2','MobileNet_V2_Weights','flowers_mobilenet_v2_model.pth'),
77
+ #('densenet121','DenseNet121_Weights','flowers_densenet121_model.pth'),
78
+ #('inception_v3','Inception_V3_Weights','flowers_inception_v3_model.pth'),
79
+ #('squeezenet1_1','SqueezeNet1_1_Weights','flowers_squeezenet1_1_model.pth'),
80
+ #('vgg16','VGG16_Weights','flowers_vgg16_model.pth'),
81
+ #('resnet18','ResNet18_Weights','flowers_resnet18_model.pth'),
82
+ #('swin_b','Swin_B_Weights','flowers_swin_b_model.pth'),
83
+ #('vit_b_16', 'ViT_B_16_Weights','flowers_vit_b_16_model.pth')
84
+ ]
85
+ #model_name, model_weights, model_path = ('efficientnet_b2','EfficientNet_B2_Weights','flowers_efficientnet_b2_model.pth')
86
  #model_name, model_weights, model_path = ('alexnet','AlexNet_Weights','flowers_alexnet_model.pth')
87
+ list_of_outputs = []
88
+ for model_name, model_weights, model_path in list_of_models_and_weights:
89
+ checkpoint = torch.load(model_path, map_location='cpu')
90
+ pretrained_weights = eval(f"models.{model_weights}.DEFAULT")
91
+ auto_transforms = pretrained_weights.transforms()
92
+ #pretrained_model = eval(f"torchvision.models.{model_name}(weights = pretrained_weights)")
93
+ pretrained_model = eval(f"models.{model_name}(pretrained = True)")
94
+ pretrained_model = update_last_layer_pretrained_model(pretrained_model, 102, True)
95
+ pretrained_model.class_to_idx = checkpoint['class_to_idx']
96
+ pretrained_model.class_names = checkpoint['class_names']
97
+ pretrained_model.load_state_dict(checkpoint['state_dict'])
98
+ pretrained_model.to('cpu')
99
+
100
+ #Predict
101
+ #image_path = 'which-flower/80_image_02020.jpg'
102
+ probs, idxs = pred_image(model=pretrained_model, image_path=image_path, class_names=pretrained_model.class_names, transform=auto_transforms)
103
+ names = [cat_to_name[i] for i in idxs]
104
 
105
+ list_of_outputs.append({"Prediction" : names[0], "Probability" : probs[0]})
106
+
 
 
 
107
  #Display or return to main function
108
+ #print({names[i]: float(probs[i]) for i in range(len(names))})
 
109
  #return {names[i]: float(probs[i]) for i in range(len(names))}
110
+ #oldreturn {names[i]: float(probs[i]) for i in range(len(names))}
111
+ print(pd.DataFrame(list_of_outputs))
112
+ return pd.DataFrame(list_of_outputs)
113
+
114
  examples = ['16_image_06670.jpg','33_image_06460.jpg','80_image_02020.jpg', 'Flowers.png','inference_example.png']
115
  title = "Image Classifier - Species of Flower predicted by different Models"
116
  description = "Image classifiers to recognize different species of flowers trained on 102 Category Flower Dataset"
117
  article = article="<p style='text-align: center'><a href='https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html' target='_blank'>Source 102 Flower Dataset</a></p>"
118
  interpretation = 'default'
119
  enable_queue = True
120
+ iface = gr.Interface(fn=process_input, inputs=gr.inputs.Image(type='filepath'), outputs="dataframe", examples = examples, title=title, description=description,article=article,interpretation=interpretation, enable_queue=enable_queue)
 
121
  iface.launch()
122
 
123
  #(num_top_classes=3)q