lneduchal commited on
Commit
e4f8ef6
·
1 Parent(s): 4a978a3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -0
app.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Importing all necessary libraries ------------------------------------------
3
+
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torchvision import models, transforms
13
+
14
+ import sys, os, distutils.core
15
+
16
+ import detectron2
17
+ from detectron2 import model_zoo
18
+ from detectron2.utils.logger import setup_logger
19
+ from detectron2.engine import DefaultPredictor
20
+ from detectron2.config import get_cfg
21
+
22
+
23
+ # Model setup ---------------------------------------------------------------
24
+
25
+ sys.path.insert(0, os.path.abspath("./detectron2"))
26
+ setup_logger()
27
+
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+ n_classes = 15
31
+ criterion = nn.CrossEntropyLoss()
32
+
33
+ # Main model
34
+ model = models.resnet18(pretrained = True)
35
+ for param in model.parameters():
36
+ param.require_grad = False
37
+ n_features = model.fc.in_features
38
+ model.fc = nn.Linear(n_features, n_classes)
39
+ model = model.to(device)
40
+
41
+ # Viewpoint model
42
+ model_viewpoint = models.resnet18(pretrained = True)
43
+ for param in model_viewpoint.parameters():
44
+ param.require_grad = False
45
+ n_features = model_viewpoint.fc.in_features
46
+ model_viewpoint.fc = nn.Linear(n_features, 4)
47
+ model_viewpoint = model_viewpoint.to(device)
48
+
49
+ # Typicality model
50
+ model_typicality = models.resnet18(pretrained = True)
51
+ for param in model_typicality.parameters():
52
+ param.require_grad = False
53
+ n_features = model_typicality.fc.in_features
54
+ model_typicality.fc = nn.Linear(n_features, 5)
55
+ model_typicality = model_typicality.to(device)
56
+ model_Softmax = nn.Softmax(dim = 1)
57
+ cos = nn.CosineSimilarity()
58
+
59
+ # Transformations to the test set
60
+ test_transforms = transforms.Compose(
61
+ [transforms.Resize(size = (224, 224)),
62
+ transforms.ToTensor(),
63
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
64
+ ]
65
+ )
66
+
67
+
68
+ # Helper functions ----------------------------------------------------------
69
+
70
+ def accuracy(y_pred, y):
71
+ top_pred = y_pred.argmax(1, keepdim = True)
72
+ correct = top_pred.eq(y.view_as(top_pred)).sum()
73
+ acc = correct.float() / y.shape[0]
74
+ return acc
75
+
76
+ activation = {}
77
+ def getActivation(name):
78
+ def hook(model_typicality, input, output):
79
+ activation[name] = output.detach()
80
+ return hook
81
+
82
+ def save_image_locally(image_array_FN, path_FN = "fake.jpg"):
83
+ image_array_FN = image_array_FN.astype(np.uint8)
84
+ data = Image.fromarray(image_array_FN)
85
+ data.save(path_FN)
86
+ return None
87
+
88
+
89
+ # Prediction ----------------------------------------------------------------
90
+
91
+ typicality_dict = {"Convertible": 0, "Hatchback": 1, "MPV": 2, "SUV": 3, "Saloon": 4}
92
+ classes_dict = {"Convertible_2000": 0, "Convertible_2003": 1, "Convertible_2006": 2, "Convertible_2007": 3, "Convertible_2008": 4, "Convertible_2009": 5, "Convertible_2010": 6, "Convertible_2011": 7, "Convertible_2012": 8, "Convertible_2013": 9, "Convertible_2014": 10, "Convertible_2015": 11, "Convertible_2016": 12, "Convertible_2017": 13, "Hatchback_2000": 14, "Hatchback_2003": 15, "Hatchback_2006": 16, "Hatchback_2007": 17, "Hatchback_2008": 18, "Hatchback_2009": 19, "Hatchback_2010": 20, "Hatchback_2011": 21, "Hatchback_2012": 22, "Hatchback_2013": 23, "Hatchback_2014": 24, "Hatchback_2015": 25, "Hatchback_2016": 26, "Hatchback_2017": 27, "MPV_2000": 28, "MPV_2003": 29, "MPV_2006": 30, "MPV_2007": 31, "MPV_2008": 32, "MPV_2009": 33, "MPV_2010": 34, "MPV_2011": 35, "MPV_2012": 36, "MPV_2013": 37, "MPV_2014": 38, "MPV_2015": 39, "MPV_2016": 40, "MPV_2017": 41, "MPV_2018": 42, "SUV_2000": 43, "SUV_2003": 44, "SUV_2006": 45, "SUV_2007": 46, "SUV_2008": 47, "SUV_2009": 48, "SUV_2010": 49, "SUV_2011": 50, "SUV_2012": 51, "SUV_2013": 52, "SUV_2014": 53, "SUV_2015": 54, "SUV_2016": 55, "SUV_2017": 56, "SUV_2018": 57, "Saloon_2000": 58, "Saloon_2003": 59, "Saloon_2006": 60, "Saloon_2007": 61, "Saloon_2008": 62, "Saloon_2009": 63, "Saloon_2010": 64, "Saloon_2011": 65, "Saloon_2012": 66, "Saloon_2013": 67, "Saloon_2014": 68, "Saloon_2015": 69, "Saloon_2016": 70, "Saloon_2017": 71, "Saloon_2018": 72}
93
+ years_dict = {"2000": 0, "2003": 1, "2006": 2, "2007": 3, "2008": 4, "2009": 5, "2010": 6, "2011": 7, "2012": 8, "2013": 9, "2014": 10, "2015": 11, "2016": 12, "2017": 13, "2018": 14}
94
+
95
+
96
+ dist = distutils.core.run_setup("./detectron2/setup.py")
97
+ cfg = get_cfg()
98
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
99
+ cfg.model.roi_heads.score_thresh_test = 0.5
100
+ cfg.model.weights = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
101
+ cfg.model.device = "cpu"
102
+ predictor = DefaultPredictor(cfg)
103
+
104
+ def predict(img_F):
105
+
106
+ target_class = 2
107
+
108
+ img = Image.fromarray(img_F.astype("uint8"), "RGB")
109
+ img = np.array(img)
110
+
111
+ outputs = predictor(img)
112
+ masks = outputs["instances"].pred_masks
113
+
114
+ pred_classes = outputs["instances"].pred_classes.tolist()
115
+ pred_boxes = list(outputs["instances"].pred_boxes)
116
+
117
+ areas = torch.sum(torch.flatten(masks, start_dim = 1), dim = 1).tolist()
118
+ total_area = []
119
+ car_area = []
120
+
121
+ for idx in range(len(pred_classes)):
122
+ if pred_classes[idx] == target_class:
123
+ total_area.append(areas[idx])
124
+ car_area.append(idx)
125
+
126
+ if len(car_area) == 0:
127
+ img = Image.open("init.jpg")
128
+ img = np.array(img)
129
+ text_output = "Sorry! I am not able to recognize a car in this image. Please upload a new photo!"
130
+ return text_output, img
131
+
132
+ local_idx = total_area.index(max(total_area))
133
+ global_idx = car_area[local_idx]
134
+
135
+ unsq = outputs["instances"].pred_masks[index_global].unsqueeze(-1).to("cpu")
136
+ mult = torch.tensor(img) * unsq
137
+
138
+ unsq = unsq.int()
139
+ unsq[unsq == 0] = 255
140
+ unsq[unsq == 1] = 0
141
+ mult = mult + unsq
142
+ res = mult.numpy()
143
+
144
+ save_image_locally(res, path_FN = "fake.jpg")
145
+
146
+ img_pred = Image.open("fake.jpg")
147
+ img_pred = test_transforms(img_pred)
148
+
149
+ model_viewpoint.load_state_dict(torch.load("model_viewpoint.pt", map_location = torch.device("cpu")))
150
+ model_viewpoint.eval()
151
+ y_pred = model_viewpoint(img_pred.unsqueeze(0))
152
+ y_pred = model_Softmax(y_pred)
153
+ top_pred = y_pred.argmax(1, keepdim = True)
154
+
155
+ if top_pred.item() not in [0, 6] :
156
+ img = Image.open("fake.jpg")
157
+ img = np.array(img)
158
+ text_output = "Sorry! I am not able to recognize a frontal view of a car in this image. Please upload a new photo!"
159
+ return text_output, img
160
+
161
+ model.load_state_dict(torch.load("model_modernity.pt", map_location = torch.device("cpu")))
162
+ model.eval()
163
+
164
+ score_t = model(img_pred.unsqueeze(0))
165
+ score_t = model_Softmax(score_t)
166
+ model_year = score_curr.argmax(1, keepdim = True).item()
167
+ score_t = torch.mul(torch.range(0, 14).to(device), torch.reshape(score_t, (-1, ))).sum().item()
168
+
169
+ model_typicality.load_state_dict(torch.load("model_typicality.pt", map_location = torch.device("cpu")))
170
+ model_typicality.eval()
171
+ model_part = model_typicality(img_pred.unsqueeze(0))
172
+ model_part = model_Softmax(model_part)
173
+ model_part = model_part.argmax(1, keepdim = True).item()
174
+
175
+ model_avg = pd.DataFrame()
176
+ h1 = model_typicality.avgpool.register_forward_hook(getActivation("avgpool"))
177
+ out = model_typicality(img_pred.unsqueeze(0))
178
+ act_pool_t = activation["avgpool"]
179
+ h1.remove()
180
+
181
+ model_year = list(years_dict.keys())[list(years_dict.values()).index(model_year)]
182
+ model_part = list(typicality_dict.keys())[list(typicality_dict.values()).index(model_part)]
183
+ true_idx = classes_dict[model_part + "_" + model_year]
184
+
185
+ morph_avg = torch.load("morph.pt")
186
+ cos_t = cos(morph_avg[true_idx], act_pool_t).item()
187
+
188
+ txt = "Modernity score:", str(round(score_t, 2)), "| Typicality score:", str(round(cos_t, 2))
189
+
190
+ return txt, res
191
+
192
+
193
+ # Launching the app ---------------------------------------------------------
194
+
195
+ interface = gr.Interface(
196
+ predict,
197
+ inputs = "image",
198
+ outputs = ["text", gr.Image(type = "pil")],
199
+ title = "Let's classify your car!")
200
+ interface.launch()