adrianzarbock commited on
Commit
b3a841c
·
1 Parent(s): 01d5492

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -22
app.py CHANGED
@@ -1,9 +1,9 @@
 
1
  import os
2
  os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
3
  os.system('pip install opencv-python')
4
 
5
- # Some basic setup:
6
- # Setup detectron2 logger
7
  import torch, detectron2
8
  from detectron2.utils.logger import setup_logger
9
  setup_logger()
@@ -11,7 +11,11 @@ setup_logger()
11
  # import some common libraries
12
  import numpy as np
13
  import os, json, cv2
14
- #from google.colab.patches import cv2_imshow
 
 
 
 
15
 
16
  # import some common detectron2 utilities
17
  from detectron2 import model_zoo
@@ -20,21 +24,12 @@ from detectron2.config import get_cfg
20
  from detectron2.utils.visualizer import Visualizer
21
  from detectron2.data import MetadataCatalog, DatasetCatalog
22
 
23
- import matplotlib.pyplot as plt
24
- import pandas as pd
25
- from PIL import Image
26
- from torchvision import transforms
27
- from torchvision import models
28
- from torch import nn
29
-
30
  import gradio as gr
31
 
32
- # enable computation on GPU if available
33
  DEVICE = 'cpu'
34
 
35
- #im = cv2.imread("./input.jpg")
36
- #cv2_imshow(im)
37
-
38
  # load model
39
  model = models.resnet18(pretrained=True)
40
  num_features = model.fc.in_features
@@ -43,11 +38,14 @@ model.fc = nn.Linear(num_features, 5)
43
  # insert trained paramters
44
  model.load_state_dict(torch.load('model_modernity.pth', map_location=torch.device('cpu')))
45
 
 
46
  model.eval()
47
 
 
48
  mean = [0.485, 0.456, 0.406]
49
  std=[0.229, 0.224, 0.225]
50
 
 
51
  test_transform = transforms.Compose([
52
  transforms.Resize((224,224)),
53
  transforms.ToTensor(),
@@ -55,78 +53,116 @@ test_transform = transforms.Compose([
55
  std=std)
56
  ])
57
 
 
58
  i1 = gr.inputs.Image(type="numpy", label="Input image")
59
  o1 = gr.outputs.Image(type="pil", label="Cropped image")
60
  o2 = gr.outputs.Textbox(label="Modernity score")
61
 
 
62
  def modernity(im):
 
 
63
  cfg = get_cfg()
64
- # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
65
  cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
66
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
67
- # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
68
  cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
69
  cfg.MODEL.DEVICE='cpu'
70
  predictor = DefaultPredictor(cfg)
71
  outputs = predictor(im)
72
 
 
73
  masks = outputs['instances'].pred_masks.to('cpu').numpy()
74
 
 
75
  obj = []
76
  obj_size = []
77
 
 
78
  for idx, data in enumerate(outputs['instances'].pred_classes):
79
  num = data.item()
80
  obj.append(MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes[num])
81
  obj_size.append(masks[idx].sum())
82
 
 
83
  if 'car' not in obj:
84
-
 
85
  v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
86
  out = v.draw_instance_predictions(outputs["instances"].to('cpu'))
87
  img = (out.get_image()[:, :, ::-1])
 
 
88
  out = 'No automobiles were found in the image.'
89
 
90
  else:
91
-
 
92
  objects = pd.DataFrame({'obj': obj,
93
  'obj_size': obj_size})
 
 
94
  item_mask = masks[objects[objects['obj'] == 'car']['obj_size'].idxmax()]
95
 
 
96
  segmentation = np.where(item_mask == True)
97
 
 
98
  x_min = int(np.min(segmentation[1]))
99
  x_max = int(np.max(segmentation[1]))
100
  y_min = int(np.min(segmentation[0]))
101
  y_max = int(np.max(segmentation[0]))
102
 
 
103
  cropped = Image.fromarray(im[y_min:y_max, x_min:x_max, :], mode='RGB')
104
 
 
105
  mask = Image.fromarray((item_mask * 255).astype('uint8'))
106
-
 
107
  cropped_mask = mask.crop((x_min, y_min, x_max, y_max))
108
 
 
109
  background = Image.new(mode='RGB', size=cropped_mask.size, color='white')
 
 
110
  paste_position = (0,0)
111
 
 
112
  new_fg_image = Image.new('RGB', background.size)
113
  new_fg_image.paste(cropped, paste_position)
 
 
 
114
 
115
- composite = Image.composite(new_fg_image, background, cropped_mask)
116
- img = composite
117
-
118
  img_t = test_transform(img).to(DEVICE)
 
 
119
  out = model(img_t[None, :])
 
 
120
  softmax = nn.Softmax(dim=1)
121
  out = softmax(out)
 
 
122
  label_classes=torch.tensor([0,1,2,3,4]).to(DEVICE)
 
 
123
  out = round((label_classes * out).sum(axis=1).item(),1)
124
 
125
  return img, out
126
 
 
127
  title = 'Design Modernity of Automobiles'
 
 
128
  description = "Demo for design modernity of automobiles. To use it, simply upload your image, or click one of the examples to load them."
 
 
129
  examples = [['input.jpg'],['input1.jpg']]
 
 
130
  interface = gr.Interface(modernity,inputs=i1, outputs=[o1, o2], title=title, description=description, examples=examples, cache_examples=False)
131
 
 
132
  interface.launch()
 
1
+ # general setup
2
  import os
3
  os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
4
  os.system('pip install opencv-python')
5
 
6
+ # setup detectron2 logger
 
7
  import torch, detectron2
8
  from detectron2.utils.logger import setup_logger
9
  setup_logger()
 
11
  # import some common libraries
12
  import numpy as np
13
  import os, json, cv2
14
+ import pandas as pd
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ from torchvision import models
18
+ from torch import nn
19
 
20
  # import some common detectron2 utilities
21
  from detectron2 import model_zoo
 
24
  from detectron2.utils.visualizer import Visualizer
25
  from detectron2.data import MetadataCatalog, DatasetCatalog
26
 
27
+ # import gradio
 
 
 
 
 
 
28
  import gradio as gr
29
 
30
+ # set device
31
  DEVICE = 'cpu'
32
 
 
 
 
33
  # load model
34
  model = models.resnet18(pretrained=True)
35
  num_features = model.fc.in_features
 
38
  # insert trained paramters
39
  model.load_state_dict(torch.load('model_modernity.pth', map_location=torch.device('cpu')))
40
 
41
+ # enable model eval
42
  model.eval()
43
 
44
+ # define mean and std of resent training data
45
  mean = [0.485, 0.456, 0.406]
46
  std=[0.229, 0.224, 0.225]
47
 
48
+ # define transforms
49
  test_transform = transforms.Compose([
50
  transforms.Resize((224,224)),
51
  transforms.ToTensor(),
 
53
  std=std)
54
  ])
55
 
56
+ # define input and outputs
57
  i1 = gr.inputs.Image(type="numpy", label="Input image")
58
  o1 = gr.outputs.Image(type="pil", label="Cropped image")
59
  o2 = gr.outputs.Textbox(label="Modernity score")
60
 
61
+ # define function to be called by gradio interface
62
  def modernity(im):
63
+
64
+ # create detectron2 config and detectron2 DefaultPredictor to run inference on image
65
  cfg = get_cfg()
 
66
  cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
67
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
 
68
  cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
69
  cfg.MODEL.DEVICE='cpu'
70
  predictor = DefaultPredictor(cfg)
71
  outputs = predictor(im)
72
 
73
+ # get all masks of input image
74
  masks = outputs['instances'].pred_masks.to('cpu').numpy()
75
 
76
+ # create empty lists for objects names and object sizes
77
  obj = []
78
  obj_size = []
79
 
80
+ # iterate over all detected objects in input image to obtain object names and object sizes
81
  for idx, data in enumerate(outputs['instances'].pred_classes):
82
  num = data.item()
83
  obj.append(MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes[num])
84
  obj_size.append(masks[idx].sum())
85
 
86
+ # define output if there is no automobile detected
87
  if 'car' not in obj:
88
+
89
+ # return image with all detected objects highlighted
90
  v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
91
  out = v.draw_instance_predictions(outputs["instances"].to('cpu'))
92
  img = (out.get_image()[:, :, ::-1])
93
+
94
+ # return message
95
  out = 'No automobiles were found in the image.'
96
 
97
  else:
98
+
99
+ # create data frame containing all object names and sizes
100
  objects = pd.DataFrame({'obj': obj,
101
  'obj_size': obj_size})
102
+
103
+ # get mask of the largest object that is labeled as car
104
  item_mask = masks[objects[objects['obj'] == 'car']['obj_size'].idxmax()]
105
 
106
+ # create segmentation
107
  segmentation = np.where(item_mask == True)
108
 
109
+ # get x and y boundaries
110
  x_min = int(np.min(segmentation[1]))
111
  x_max = int(np.max(segmentation[1]))
112
  y_min = int(np.min(segmentation[0]))
113
  y_max = int(np.max(segmentation[0]))
114
 
115
+ # create cropped image
116
  cropped = Image.fromarray(im[y_min:y_max, x_min:x_max, :], mode='RGB')
117
 
118
+ # create mask
119
  mask = Image.fromarray((item_mask * 255).astype('uint8'))
120
+
121
+ # create cropped mask
122
  cropped_mask = mask.crop((x_min, y_min, x_max, y_max))
123
 
124
+ # create background
125
  background = Image.new(mode='RGB', size=cropped_mask.size, color='white')
126
+
127
+ # define paste position
128
  paste_position = (0,0)
129
 
130
+ # create foreground image
131
  new_fg_image = Image.new('RGB', background.size)
132
  new_fg_image.paste(cropped, paste_position)
133
+
134
+ # composite final image
135
+ img = Image.composite(new_fg_image, background, cropped_mask)
136
 
137
+ # apply previously defined transformations
 
 
138
  img_t = test_transform(img).to(DEVICE)
139
+
140
+ # feed transformed image to the model
141
  out = model(img_t[None, :])
142
+
143
+ # apply softmax
144
  softmax = nn.Softmax(dim=1)
145
  out = softmax(out)
146
+
147
+ # get label classes
148
  label_classes=torch.tensor([0,1,2,3,4]).to(DEVICE)
149
+
150
+ # compute modernity score
151
  out = round((label_classes * out).sum(axis=1).item(),1)
152
 
153
  return img, out
154
 
155
+ # set interface title
156
  title = 'Design Modernity of Automobiles'
157
+
158
+ # set interface description
159
  description = "Demo for design modernity of automobiles. To use it, simply upload your image, or click one of the examples to load them."
160
+
161
+ # include example images
162
  examples = [['input.jpg'],['input1.jpg']]
163
+
164
+ # define interface
165
  interface = gr.Interface(modernity,inputs=i1, outputs=[o1, o2], title=title, description=description, examples=examples, cache_examples=False)
166
 
167
+ # launch interface
168
  interface.launch()