Spaces: Runtime error

adrianzarbock committed · Commit b3a841c · 1 Parent(s): 01d5492
Update app.py

app.py CHANGED
@@ -1,9 +1,9 @@
+# general setup
 import os
 os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
 os.system('pip install opencv-python')
 
-#
-# Setup detectron2 logger
+# setup detectron2 logger
 import torch, detectron2
 from detectron2.utils.logger import setup_logger
 setup_logger()
@@ -11,7 +11,11 @@ setup_logger()
 # import some common libraries
 import numpy as np
 import os, json, cv2
-
+import pandas as pd
+from PIL import Image
+from torchvision import transforms
+from torchvision import models
+from torch import nn
 
 # import some common detectron2 utilities
 from detectron2 import model_zoo
@@ -20,21 +24,12 @@ from detectron2.config import get_cfg
 from detectron2.utils.visualizer import Visualizer
 from detectron2.data import MetadataCatalog, DatasetCatalog
 
-import
-import pandas as pd
-from PIL import Image
-from torchvision import transforms
-from torchvision import models
-from torch import nn
-
+# import gradio
 import gradio as gr
 
-#
+# set device
 DEVICE = 'cpu'
 
-#im = cv2.imread("./input.jpg")
-#cv2_imshow(im)
-
 # load model
 model = models.resnet18(pretrained=True)
 num_features = model.fc.in_features
@@ -43,11 +38,14 @@ model.fc = nn.Linear(num_features, 5)
 # insert trained paramters
 model.load_state_dict(torch.load('model_modernity.pth', map_location=torch.device('cpu')))
 
+# enable model eval
 model.eval()
 
+# define mean and std of resent training data
 mean = [0.485, 0.456, 0.406]
 std=[0.229, 0.224, 0.225]
 
+# define transforms
 test_transform = transforms.Compose([
     transforms.Resize((224,224)),
     transforms.ToTensor(),
@@ -55,78 +53,116 @@ test_transform = transforms.Compose([
                          std=std)
 ])
 
+# define input and outputs
 i1 = gr.inputs.Image(type="numpy", label="Input image")
 o1 = gr.outputs.Image(type="pil", label="Cropped image")
 o2 = gr.outputs.Textbox(label="Modernity score")
 
+# define function to be called by gradio interface
 def modernity(im):
+
+    # create detectron2 config and detectron2 DefaultPredictor to run inference on image
     cfg = get_cfg()
-    # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
-    # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
     cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
     cfg.MODEL.DEVICE='cpu'
     predictor = DefaultPredictor(cfg)
     outputs = predictor(im)
 
+    # get all masks of input image
     masks = outputs['instances'].pred_masks.to('cpu').numpy()
 
+    # create empty lists for objects names and object sizes
     obj = []
     obj_size = []
 
+    # iterate over all detected objects in input image to obtain object names and object sizes
     for idx, data in enumerate(outputs['instances'].pred_classes):
         num = data.item()
         obj.append(MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes[num])
         obj_size.append(masks[idx].sum())
 
+    # define output if there is no automobile detected
     if 'car' not in obj:
-
+
+        # return image with all detected objects highlighted
         v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
         out = v.draw_instance_predictions(outputs["instances"].to('cpu'))
         img = (out.get_image()[:, :, ::-1])
+
+        # return message
         out = 'No automobiles were found in the image.'
 
     else:
-
+
+        # create data frame containing all object names and sizes
         objects = pd.DataFrame({'obj': obj,
                                 'obj_size': obj_size})
+
+        # get mask of the largest object that is labeled as car
         item_mask = masks[objects[objects['obj'] == 'car']['obj_size'].idxmax()]
 
+        # create segmentation
         segmentation = np.where(item_mask == True)
 
+        # get x and y boundaries
         x_min = int(np.min(segmentation[1]))
         x_max = int(np.max(segmentation[1]))
         y_min = int(np.min(segmentation[0]))
         y_max = int(np.max(segmentation[0]))
 
+        # create cropped image
         cropped = Image.fromarray(im[y_min:y_max, x_min:x_max, :], mode='RGB')
 
+        # create mask
         mask = Image.fromarray((item_mask * 255).astype('uint8'))
-
+
+        # create cropped mask
         cropped_mask = mask.crop((x_min, y_min, x_max, y_max))
 
+        # create background
         background = Image.new(mode='RGB', size=cropped_mask.size, color='white')
+
+        # define paste position
        paste_position = (0,0)
 
+        # create foreground image
         new_fg_image = Image.new('RGB', background.size)
         new_fg_image.paste(cropped, paste_position)
+
+        # composite final image
+        img = Image.composite(new_fg_image, background, cropped_mask)
 
-
-        img = composite
-
+        # apply previously defined transformations
         img_t = test_transform(img).to(DEVICE)
+
+        # feed transformed image to the model
         out = model(img_t[None, :])
+
+        # apply softmax
         softmax = nn.Softmax(dim=1)
         out = softmax(out)
+
+        # get label classes
         label_classes=torch.tensor([0,1,2,3,4]).to(DEVICE)
+
+        # compute modernity score
         out = round((label_classes * out).sum(axis=1).item(),1)
 
    return img, out
 
+# set interface title
 title = 'Design Modernity of Automobiles'
+
+# set interface description
 description = "Demo for design modernity of automobiles. To use it, simply upload your image, or click one of the examples to load them."
+
+# include example images
 examples = [['input.jpg'],['input1.jpg']]
+
+# define interface
 interface = gr.Interface(modernity,inputs=i1, outputs=[o1, o2], title=title, description=description, examples=examples, cache_examples=False)
 
+# launch interface
 interface.launch()
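Note on the scoring step added in this commit: modernity() turns the five class logits into probabilities with a softmax and then returns the probability-weighted mean of the ordinal class labels 0-4, i.e. the expected class index rather than the argmax. A minimal, self-contained sketch of just that computation (the logits tensor below is made up for illustration; in the app it comes from the fine-tuned ResNet-18):

    import torch
    from torch import nn

    # hypothetical raw output of the 5-class model for one image
    logits = torch.tensor([[0.2, 1.5, 0.3, 2.0, 0.1]])

    # turn logits into class probabilities
    probs = nn.Softmax(dim=1)(logits)

    # ordinal labels 0..4, matching label_classes in app.py
    label_classes = torch.arange(5, dtype=torch.float)

    # modernity score = expected class label under the predicted distribution
    score = round((label_classes * probs).sum(dim=1).item(), 1)
    print(score)  # 2.2

Using the expectation instead of the argmax yields a continuous score, so an image the model places between two adjacent design eras gets an intermediate value.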