rdkulkarni committed
Commit 0046083
1 Parent(s): c633e82

Create new file

Files changed (1):
  1. app-multiple-inp-out.py +105 -0
app-multiple-inp-out.py ADDED
@@ -0,0 +1,105 @@
+ import gradio as gr
+ import torch
+ import torchvision
+ from PIL import Image
+ from torchvision import models
+ from torch import nn
+ from typing import List
+ import json
+
+ # Read labels file mapping flower category ids to display names
+ with open('cat_to_name.json', 'r') as f:
+     cat_to_name = json.load(f)
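+ # Expected shape of cat_to_name, assumed from how it is used below:
+ # {"1": "pink primrose", ...}, keyed by the dataset's numeric category id as a string.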
+
+ # Freeze all parameters when feature extracting, so only the new head trains
+ def set_parameter_requires_grad(model, feature_extracting):
+     if feature_extracting:
+         for param in model.parameters():
+             param.requires_grad = False
+
+ # Swap the classification head of a pretrained model for a fresh num_classes-way layer
+ def update_last_layer_pretrained_model(pretrained_model, num_classes, feature_extract):
+     set_parameter_requires_grad(pretrained_model, feature_extract)
+     arch = pretrained_model.__class__.__name__.lower()
+     if hasattr(pretrained_model, 'fc') and 'resnet' in arch:  # resnet
+         num_ftrs = pretrained_model.fc.in_features
+         pretrained_model.fc = nn.Linear(num_ftrs, num_classes, bias=True)
+     elif hasattr(pretrained_model, 'classifier') and ('alexnet' in arch or 'vgg' in arch):  # alexnet, vgg
+         num_ftrs = pretrained_model.classifier[6].in_features
+         pretrained_model.classifier[6] = nn.Linear(num_ftrs, num_classes, bias=True)
+     elif hasattr(pretrained_model, 'classifier') and 'squeezenet' in arch:  # squeezenet
+         pretrained_model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
+         pretrained_model.num_classes = num_classes
+     elif hasattr(pretrained_model, 'classifier') and ('efficientnet' in arch or 'mobilenet' in arch):  # efficientnet, mobilenet
+         num_ftrs = pretrained_model.classifier[1].in_features
+         pretrained_model.classifier[1] = nn.Linear(num_ftrs, num_classes, bias=True)
+     elif hasattr(pretrained_model, 'AuxLogits') and 'inception' in arch:  # inception
+         num_ftrs = pretrained_model.AuxLogits.fc.in_features
+         pretrained_model.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)  # auxiliary net
+         num_ftrs = pretrained_model.fc.in_features
+         pretrained_model.fc = nn.Linear(num_ftrs, num_classes)  # primary net
+     elif hasattr(pretrained_model, 'classifier') and 'densenet' in arch:  # densenet
+         num_ftrs = pretrained_model.classifier.in_features
+         pretrained_model.classifier = nn.Linear(num_ftrs, num_classes, bias=True)
+     elif hasattr(pretrained_model, 'heads') and 'visiontransformer' in arch:  # vit
+         num_ftrs = pretrained_model.heads.head.in_features
+         pretrained_model.heads.head = nn.Linear(num_ftrs, num_classes, bias=True)
+     elif hasattr(pretrained_model, 'head') and 'swin' in arch:  # swin transformer
+         num_ftrs = pretrained_model.head.in_features
+         pretrained_model.head = nn.Linear(num_ftrs, num_classes, bias=True)
+     return pretrained_model
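+ # A minimal sketch of the call made below, assuming torchvision's EfficientNet-B2
+ # layout where classifier[1] is the final Linear layer:
+ #   m = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)
+ #   m = update_last_layer_pretrained_model(m, num_classes=102, feature_extract=True)
+ #   # m.classifier[1] is now a fresh Linear(..., 102); all backbone params are frozen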
+
+ # Run a single image through the model and return top-3 probabilities and labels
+ def pred_image(model, image_path, class_names=None, transform=None,
+                device: torch.device = "cuda" if torch.cuda.is_available() else "cpu"):
+     target_image = Image.open(image_path).convert('RGB')  # drop alpha channel for RGB transforms
+     if transform:
+         target_image = transform(target_image)
+     model.to(device)
+     model.eval()
+     with torch.inference_mode():
+         target_image = target_image.unsqueeze(dim=0)  # add batch dimension
+         target_image_pred = model(target_image.to(device))
+
+     target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
+     ps = target_image_pred_probs.topk(3)
+     ps_numpy = ps[0].cpu().numpy()[0]
+     top_idxs = ps[1].cpu().numpy()[0]  # move to CPU before converting to numpy
+     idxs = [class_names[i] for i in top_idxs] if class_names else top_idxs
+
+     return (ps_numpy, idxs)
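+ # Illustrative return value (hypothetical numbers; class_names assumed to be the
+ # dataset's numeric folder ids as strings):
+ #   (array([0.91, 0.06, 0.01], dtype=float32), ['76', '45', '3'])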
+
+ # Gradio handler: load the fine-tuned checkpoint, classify the image, return label dicts
+ def process_input(image_path):
+
+     # Load model
+     model_name, model_weights, model_path = ('efficientnet_b2', 'EfficientNet_B2_Weights', 'flowers_efficientnet_b2_model.pth')
+     #model_name, model_weights, model_path = ('alexnet', 'AlexNet_Weights', 'flowers_alexnet_model.pth')
+     checkpoint = torch.load(model_path, map_location='cpu')
+     pretrained_weights = getattr(models, model_weights).DEFAULT
+     auto_transforms = pretrained_weights.transforms()
+     pretrained_model = getattr(models, model_name)(weights=pretrained_weights)
+     pretrained_model = update_last_layer_pretrained_model(pretrained_model, 102, True)
+     pretrained_model.class_to_idx = checkpoint['class_to_idx']
+     pretrained_model.class_names = checkpoint['class_names']
+     pretrained_model.load_state_dict(checkpoint['state_dict'])
+     pretrained_model.to('cpu')
+
+     # Predict
+     probs, idxs = pred_image(model=pretrained_model, image_path=image_path,
+                              class_names=pretrained_model.class_names, transform=auto_transforms)
+     names = [cat_to_name[i] for i in idxs]
+
+     # Return the same confidence dict for both output Label components
+     results = {names[i]: float(probs[i]) for i in range(len(names))}
+     print(results)
+     return results, results
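+ # Sketch of how such a checkpoint could have been saved after training (assumed
+ # keys, matching the reads above; train_data is a hypothetical ImageFolder dataset):
+ #   torch.save({'state_dict': model.state_dict(),
+ #               'class_to_idx': train_data.class_to_idx,
+ #               'class_names': list(train_data.class_to_idx.keys())},
+ #              'flowers_efficientnet_b2_model.pth')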
94
+
95
+ examples = ['16_image_06670.jpg','33_image_06460.jpg','80_image_02020.jpg', 'Flowers.png','inference_example.png']
96
+ title = "Image Classifier - Species of Flower predicted by different Models"
97
+ description = "Image classifiers to recognize different species of flowers trained on 102 Category Flower Dataset"
98
+ article = article="<p style='text-align: center'><a href='https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html' target='_blank'>Source 102 Flower Dataset</a></p>"
99
+ interpretation = 'default'
100
+ enable_queue = True
101
+ iface = gr.Interface(fn=process_input, inputs=gr.inputs.Image(type='filepath'), outputs=[gr.outputs.Label(num_top_classes=3), gr.outputs.Label(num_top_classes=3)], examples = examples, title=title, description=description,article=article,interpretation=interpretation, enable_queue=enable_queue
102
+ )
103
+ iface.launch()
104
+
105
+ #(num_top_classes=3)q
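+ # Note: gr.inputs/gr.outputs and the interpretation/enable_queue kwargs are the
+ # legacy (pre-4.x) Gradio API; on current Gradio the equivalents would be
+ # gr.Image(type='filepath'), gr.Label(num_top_classes=3), and iface.queue().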