Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Importing all necessary libraries ------------------------------------------
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from torchvision import models, transforms
|
13 |
+
|
14 |
+
import sys, os, distutils.core
|
15 |
+
|
16 |
+
import detectron2
|
17 |
+
from detectron2 import model_zoo
|
18 |
+
from detectron2.utils.logger import setup_logger
|
19 |
+
from detectron2.engine import DefaultPredictor
|
20 |
+
from detectron2.config import get_cfg
|
21 |
+
|
22 |
+
|
23 |
+
# Model setup ---------------------------------------------------------------
|
24 |
+
|
25 |
+
# Put a local ./detectron2 checkout at the front of the import path and turn on
# detectron2's logger.
# NOTE(review): `import detectron2` at the top of the file has already executed
# by the time this line runs, so the path entry can only influence later
# imports -- confirm it is still needed.
sys.path.insert(0, os.path.abspath("./detectron2"))
setup_logger()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of outputs of the main model; years_dict further down has 15 entries.
n_classes = 15
criterion = nn.CrossEntropyLoss()


def _build_frozen_resnet18(n_outputs):
    """Return a ResNet-18 with a frozen backbone and a fresh linear head.

    Args:
        n_outputs: Number of output units of the replacement ``fc`` layer.

    Returns:
        The network, moved to the module-level ``device``.
    """
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; kept as-is for compatibility with the pinned version.
    net = models.resnet18(pretrained=True)
    for param in net.parameters():
        # The tensor attribute is `requires_grad`. The previous spelling
        # (`require_grad`) only created an unused attribute, so nothing was
        # actually frozen.
        param.requires_grad = False
    in_features = net.fc.in_features
    # The new head is created after the freeze, so it stays trainable.
    net.fc = nn.Linear(in_features, n_outputs)
    return net.to(device)


# Main model
model = _build_frozen_resnet18(n_classes)
# Kept as a module-level name for backward compatibility.
n_features = model.fc.in_features

# Viewpoint model (4 outputs)
model_viewpoint = _build_frozen_resnet18(4)

# Typicality model (5 outputs; typicality_dict further down has 5 entries)
model_typicality = _build_frozen_resnet18(5)

model_Softmax = nn.Softmax(dim=1)
cos = nn.CosineSimilarity()

# Transformations to the test set
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
|
66 |
+
|
67 |
+
|
68 |
+
# Helper functions ----------------------------------------------------------
|
69 |
+
|
70 |
+
def accuracy(y_pred, y):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        y_pred: Tensor of per-class scores; the arg-max is taken over dimension 1.
        y: Tensor of class labels, reshaped to match the arg-max result.

    Returns:
        A float tensor holding ``correct / y.shape[0]``.
    """
    top_pred = y_pred.argmax(1, keepdim=True)
    correct = top_pred.eq(y.view_as(top_pred)).sum()
    return correct.float() / y.shape[0]
|
75 |
+
|
76 |
+
# Module-level store filled by the hooks created below: name -> detached output.
activation = {}


def getActivation(name):
    """Return a forward hook that records the hooked module's output.

    Args:
        name: Key under which the output is stored in ``activation``.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook signature.
        Each time it fires it overwrites ``activation[name]`` with
        ``output.detach()``.
    """

    # PyTorch calls hooks positionally, so the parameter names are free to
    # choose; these avoid shadowing the builtin `input` and the global model.
    def hook(module, inputs, output):
        activation[name] = output.detach()

    return hook
|
81 |
+
|
82 |
+
def save_image_locally(image_array_FN, path_FN="fake.jpg"):
    """Write an array to disk as an image file.

    Args:
        image_array_FN: NumPy array with the pixel data; it is cast to
            ``uint8`` before being handed to PIL.
        path_FN: Destination path; the default is ``"fake.jpg"``.

    Returns:
        None.
    """
    pixels = image_array_FN.astype(np.uint8)
    Image.fromarray(pixels).save(path_FN)
|
87 |
+
|
88 |
+
|
89 |
+
# Prediction ----------------------------------------------------------------
|
90 |
+
|
91 |
+
# Body-type name -> output index of the typicality model; predict() looks it up
# in reverse to turn the arg-max index back into a name.
typicality_dict = {"Convertible": 0, "Hatchback": 1, "MPV": 2, "SUV": 3, "Saloon": 4}
|
92 |
+
# "<body type>_<year>" -> index; predict() uses the index to select an entry of
# the object loaded from morph.pt. Not every combination exists: there is no
# "Convertible_2018" and no "Hatchback_2018" key.
classes_dict = {"Convertible_2000": 0, "Convertible_2003": 1, "Convertible_2006": 2, "Convertible_2007": 3, "Convertible_2008": 4, "Convertible_2009": 5, "Convertible_2010": 6, "Convertible_2011": 7, "Convertible_2012": 8, "Convertible_2013": 9, "Convertible_2014": 10, "Convertible_2015": 11, "Convertible_2016": 12, "Convertible_2017": 13, "Hatchback_2000": 14, "Hatchback_2003": 15, "Hatchback_2006": 16, "Hatchback_2007": 17, "Hatchback_2008": 18, "Hatchback_2009": 19, "Hatchback_2010": 20, "Hatchback_2011": 21, "Hatchback_2012": 22, "Hatchback_2013": 23, "Hatchback_2014": 24, "Hatchback_2015": 25, "Hatchback_2016": 26, "Hatchback_2017": 27, "MPV_2000": 28, "MPV_2003": 29, "MPV_2006": 30, "MPV_2007": 31, "MPV_2008": 32, "MPV_2009": 33, "MPV_2010": 34, "MPV_2011": 35, "MPV_2012": 36, "MPV_2013": 37, "MPV_2014": 38, "MPV_2015": 39, "MPV_2016": 40, "MPV_2017": 41, "MPV_2018": 42, "SUV_2000": 43, "SUV_2003": 44, "SUV_2006": 45, "SUV_2007": 46, "SUV_2008": 47, "SUV_2009": 48, "SUV_2010": 49, "SUV_2011": 50, "SUV_2012": 51, "SUV_2013": 52, "SUV_2014": 53, "SUV_2015": 54, "SUV_2016": 55, "SUV_2017": 56, "SUV_2018": 57, "Saloon_2000": 58, "Saloon_2003": 59, "Saloon_2006": 60, "Saloon_2007": 61, "Saloon_2008": 62, "Saloon_2009": 63, "Saloon_2010": 64, "Saloon_2011": 65, "Saloon_2012": 66, "Saloon_2013": 67, "Saloon_2014": 68, "Saloon_2015": 69, "Saloon_2016": 70, "Saloon_2017": 71, "Saloon_2018": 72}
|
93 |
+
# Year string -> output index of the main (modernity) model; predict() looks it
# up in reverse to turn the arg-max index back into a year.
years_dict = {"2000": 0, "2003": 1, "2006": 2, "2007": 3, "2008": 4, "2009": 5, "2010": 6, "2011": 7, "2012": 8, "2013": 9, "2014": 10, "2015": 11, "2016": 12, "2017": 13, "2018": 14}
|
94 |
+
|
95 |
+
|
96 |
+
# Execute detectron2's setup script through distutils. The returned object is
# not used anywhere else in the visible code.
# NOTE(review): `distutils` was removed from the standard library in
# Python 3.12 -- confirm the runtime version or drop this call.
dist = distutils.core.run_setup("./detectron2/setup.py")

# Mask R-CNN (R50-FPN, 3x schedule) instance-segmentation predictor.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
# detectron2 config keys are upper-case; the lower-case spellings used before
# (`cfg.model...`) are not valid keys of the config node.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
# Segmentation always runs on the CPU, independent of the module-level `device`.
cfg.MODEL.DEVICE = "cpu"
predictor = DefaultPredictor(cfg)
|
103 |
+
|
104 |
+
# Checkpoint paths whose weights have already been loaded into their network.
_WEIGHTS_LOADED = set()


def _load_weights_once(net, path):
    """Load the checkpoint at *path* into *net* on first use and set eval mode."""
    if path not in _WEIGHTS_LOADED:
        net.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
        _WEIGHTS_LOADED.add(path)
    net.eval()
    return net


def _no_car_response():
    """Return the "no car found" message together with the init.jpg placeholder."""
    with Image.open("init.jpg") as placeholder:
        img = np.array(placeholder)
    text_output = "Sorry! I am not able to recognize a car in this image. Please upload a new photo!"
    return text_output, img


def predict(img_F):
    """Segment the largest car in an image and score it.

    Steps: run the detectron2 predictor, keep the largest instance of class 2,
    paint everything outside its mask white, round-trip the result through
    ``fake.jpg``, check the viewpoint, then compute a modernity score (the
    probability-weighted mean of the year-class indices) and a typicality
    score (a cosine similarity against an entry loaded from ``morph.pt``).

    Args:
        img_F: Image as a NumPy array, as delivered by the Gradio ``"image"``
            input; ``None`` when nothing was uploaded.

    Returns:
        A ``(text, image)`` tuple: a message or the formatted scores, and a
        ``uint8`` NumPy image (the placeholder, the re-read ``fake.jpg``, or
        the masked car).
    """
    # Class index kept from the detector output; presumably COCO "car", since
    # the user-facing messages talk about cars -- TODO confirm.
    target_class = 2

    if img_F is None:
        # Nothing to analyse; treat it like an image without a car.
        return _no_car_response()

    img = np.array(Image.fromarray(img_F.astype("uint8"), "RGB"))

    # NOTE(review): the array is passed in RGB order; detectron2's
    # DefaultPredictor documents BGR input unless cfg.INPUT.FORMAT says
    # otherwise -- confirm this is intended.
    instances = predictor(img)["instances"]
    masks = instances.pred_masks
    pred_classes = instances.pred_classes.tolist()
    areas = torch.sum(torch.flatten(masks, start_dim=1), dim=1).tolist()

    # Mask area for every detected instance of the target class.
    car_areas = {
        idx: areas[idx]
        for idx, cls in enumerate(pred_classes)
        if cls == target_class
    }
    if not car_areas:
        return _no_car_response()

    # First instance with the largest mask (same tie-break as list.index(max)).
    global_idx = max(car_areas, key=car_areas.get)

    # Keep the pixels inside the mask and paint everything else white.
    car_mask = masks[global_idx].unsqueeze(-1).to("cpu").bool()
    img_t = torch.tensor(img)
    res = torch.where(car_mask, img_t, torch.full_like(img_t, 255)).numpy()

    # The classifiers read the image back from disk, i.e. after JPEG encoding.
    # NOTE(review): the fixed file name is shared by all requests.
    save_image_locally(res, path_FN="fake.jpg")
    with Image.open("fake.jpg") as masked_file:
        img_pred = test_transforms(masked_file)
    # The networks live on `device`, so the input has to be moved there too.
    batch = img_pred.unsqueeze(0).to(device)

    with torch.no_grad():
        _load_weights_once(model_viewpoint, "model_viewpoint.pt")
        viewpoint = model_Softmax(model_viewpoint(batch)).argmax(1, keepdim=True).item()

        # NOTE(review): model_viewpoint has 4 outputs, so index 6 can never be
        # predicted; kept unchanged -- confirm the intended class list.
        if viewpoint not in (0, 6):
            with Image.open("fake.jpg") as masked_file:
                img = np.array(masked_file)
            text_output = "Sorry! I am not able to recognize a frontal view of a car in this image. Please upload a new photo!"
            return text_output, img

        _load_weights_once(model, "model_modernity.pt")
        year_probs = model_Softmax(model(batch))
        year_idx = year_probs.argmax(1, keepdim=True).item()
        # Expected class index under the predicted distribution.
        # torch.arange replaces the deprecated torch.range(0, 14).
        positions = torch.arange(
            year_probs.shape[1], dtype=year_probs.dtype, device=year_probs.device
        )
        score_t = torch.mul(positions, torch.reshape(year_probs, (-1,))).sum().item()

        # One forward pass yields both the logits and the avgpool activation.
        _load_weights_once(model_typicality, "model_typicality.pt")
        hook_handle = model_typicality.avgpool.register_forward_hook(getActivation("avgpool"))
        try:
            part_logits = model_typicality(batch)
        finally:
            hook_handle.remove()
        part_idx = model_Softmax(part_logits).argmax(1, keepdim=True).item()
        act_pool_t = activation["avgpool"]

    year_by_index = {index: year for year, index in years_dict.items()}
    part_by_index = {index: part for part, index in typicality_dict.items()}
    model_year = year_by_index[year_idx]
    model_part = part_by_index[part_idx]

    modernity_txt = f"Modernity score: {round(score_t, 2)}"

    # classes_dict does not contain every body-type/year pair
    # (e.g. "Convertible_2018"), so a plain lookup could raise KeyError.
    class_key = model_part + "_" + model_year
    true_idx = classes_dict.get(class_key)
    if true_idx is None:
        txt = f"{modernity_txt} | Typicality score: not available for {class_key}"
        return txt, res

    # presumably one reference feature per class index -- TODO confirm the
    # layout of morph.pt.
    morph_avg = torch.load("morph.pt", map_location=device)
    cos_t = cos(morph_avg[true_idx], act_pool_t).item()

    # Single string for the "text" output (previously a tuple of fragments).
    txt = f"{modernity_txt} | Typicality score: {round(cos_t, 2)}"

    return txt, res
|
191 |
+
|
192 |
+
|
193 |
+
# Launching the app ---------------------------------------------------------
|
194 |
+
|
195 |
+
# Gradio UI: one image input, a text output and an image output, both produced
# by predict(). The app is launched as soon as this module is executed.
interface = gr.Interface(
    predict,
    inputs="image",
    outputs=["text", gr.Image(type="pil")],
    title="Let's classify your car!",
)
interface.launch()
|