vibha-mah committed
Commit cdb28db · 1 parent: a70de84

Update app.py

Files changed (1)
  1. app.py +86 -29
app.py CHANGED
@@ -1,7 +1,7 @@
 try:
     import detectron2
 except:
-    import os
+    import os
     os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
 
 from matplotlib.pyplot import axis
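This hunk keeps the runtime fallback that pip-installs detectron2 when the import fails; note that the bare except: catches every exception, and os must be importable inside the handler. For reference, a minimal sketch of the same fallback using only the standard library (the helper name ensure_package is hypothetical, not part of app.py):

    import importlib
    import subprocess
    import sys

    def ensure_package(module_name, pip_spec):
        # Import first; shell out to pip only when the module is missing.
        try:
            importlib.import_module(module_name)
        except ImportError:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', pip_spec])

    ensure_package('detectron2', 'git+https://github.com/facebookresearch/detectron2.git')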
@@ -10,6 +10,7 @@ import requests
 import numpy as np
 from torch import nn
 import requests
+import cv2
 
 import torch
 import detectron2
@@ -19,6 +20,8 @@ from detectron2.config import get_cfg
 from detectron2.utils.visualizer import Visualizer
 from detectron2.data import MetadataCatalog
 from detectron2.utils.visualizer import ColorMode
+from detectron2.structures import Instances
+from detectron2.structures import Boxes
 
 damage_model_path = 'damage/model_final.pth'
 scratch_model_path = 'scratch/model_final.pth'
@@ -28,7 +31,7 @@ if torch.cuda.is_available():
     device = 'cuda'
 else:
     device = 'cpu'
-
+
 cfg_scratches = get_cfg()
 cfg_scratches.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
 cfg_scratches.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8
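This hunk only touches whitespace around the scratch-model config, but it shows the setup pattern app.py uses for all three models. A sketch of how such a config typically becomes a predictor; the NUM_CLASSES value and the name predictor_scratch are assumptions, while get_cfg, model_zoo, and DefaultPredictor are standard detectron2 API:

    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.engine import DefaultPredictor

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8    # drop detections below 80% confidence
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1            # assumption: a single 'scratch' class
    cfg.MODEL.WEIGHTS = 'scratch/model_final.pth'  # fine-tuned weights shipped with the repo
    cfg.MODEL.DEVICE = 'cpu'                       # or 'cuda', matching the device check above
    predictor_scratch = DefaultPredictor(cfg)      # hypothetical name for the scratch predictor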
@@ -83,6 +86,7 @@ metadata_parts.thing_classes = ['_background_',
     'trunk',
     'wheel']
 
+
 def merge_segment(pred_segm):
     merge_dict = {}
     for i in range(len(pred_segm)):
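The context above ends the custom class-name list assigned to metadata_parts.thing_classes, which the Visualizer later uses to print readable part labels. A sketch of that mechanism with a shortened, partly hypothetical class list (the dataset key 'car_parts' is illustrative; the real list in app.py runs from '_background_' to 'wheel'):

    from detectron2.data import MetadataCatalog

    # Register label names under a dataset key so Visualizer(metadata=...) can use them.
    metadata_parts = MetadataCatalog.get('car_parts')
    metadata_parts.thing_classes = ['_background_', 'door', 'hood', 'trunk', 'wheel']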
@@ -90,36 +94,35 @@ def merge_segment(pred_segm):
         for j in range(i+1,len(pred_segm)):
             if torch.sum(pred_segm[i]*pred_segm[j])>0:
                 merge_dict[i].append(j)
-
+
     to_delete = []
     for key in merge_dict:
         for element in merge_dict[key]:
             to_delete.append(element)
-
+
     for element in to_delete:
         merge_dict.pop(element,None)
-
+
     empty_delete = []
     for key in merge_dict:
         if merge_dict[key] == []:
             empty_delete.append(key)
-
+
     for element in empty_delete:
         merge_dict.pop(element,None)
-
+
     for key in merge_dict:
         for element in merge_dict[key]:
             pred_segm[key]+=pred_segm[element]
-
+
     except_elem = list(set(to_delete))
-
+
     new_indexes = list(range(len(pred_segm)))
     for elem in except_elem:
         new_indexes.remove(elem)
-
+
     return pred_segm[new_indexes]
 
-
 def inference(image):
     img = np.array(image)
     outputs_damage = predictor_damage(img)
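merge_segment groups masks that overlap (nonzero elementwise product), adds each group's masks together, and drops the absorbed ones. A toy check of that behavior, assuming merge_dict[i] is initialized to an empty list at the top of the outer loop (that line falls outside the hunk), and cloning first because the function mutates pred_segm in place:

    import torch

    # Masks 0 and 1 overlap at (1,1); mask 2 is disjoint, so three masks merge into two groups.
    masks = torch.zeros((3, 4, 4))
    masks[0, 0:2, 0:2] = 1
    masks[1, 1:3, 1:3] = 1
    masks[2, 3:4, 3:4] = 1

    merged = merge_segment(masks.clone())
    print(merged.shape)     # expected: torch.Size([2, 4, 4])
    print(merged[0, 1, 1])  # expected: tensor(2.) where the merged masks overlapped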
@@ -136,7 +139,8 @@ def inference(image):
     parts_classes = parts_data['pred_classes']
     new_inst = detectron2.structures.Instances((1024,1024))
     new_inst.set('pred_masks',merge_segment(out_dict['pred_masks']))
-
+
+
     parts_damage_dict = {}
     parts_list_damages = []
     for part in parts_classes:
@@ -145,7 +149,7 @@ def inference(image):
         for i in range(len(parts_masks)):
             if torch.sum(parts_masks[i]*mask)>0:
                 parts_damage_dict[metadata_parts.thing_classes[parts_classes[i]]].append('scratch')
-                parts_list_damages.append(f'{metadata_parts.thing_classes[parts_classes[i]]} has scratch')
+                parts_list_damages.append(f'{metadata_parts.thing_classes[parts_classes[i]]} has scratch')
                 print(f'{metadata_parts.thing_classes[parts_classes[i]]} has scratch')
     for mask in merged_damage_masks:
         for i in range(len(parts_masks)):
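Both assignment loops use the same overlap test: two binary masks intersect exactly when the sum of their elementwise product is positive. Stated as a tiny standalone check (the helper name is illustrative):

    import torch

    def masks_overlap(mask_a, mask_b):
        # Nonzero elementwise product means at least one shared pixel.
        return torch.sum(mask_a * mask_b) > 0

    part = torch.zeros((4, 4)); part[0:2, :] = 1      # say, a 'hood' mask
    scratch = torch.zeros((4, 4)); scratch[1, 1] = 1  # one scratch pixel on the hood
    print(masks_overlap(part, scratch))               # tensor(True)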
@@ -154,35 +158,86 @@ def inference(image):
                 parts_list_damages.append(f'{metadata_parts.thing_classes[parts_classes[i]]} has damage')
                 print(f'{metadata_parts.thing_classes[parts_classes[i]]} has damage')
 
+    # Define the colors for the scratch and damage masks
+    scratch_color = (0, 0, 255) # red
+    damage_color = (0, 255, 255) # yellow
+    # Convert the scratch and damage masks to numpy arrays
+    scratch_masks_arr = np.array(scratch_masks)
+    damage_masks_arr = np.array(damage_masks)
+    # Resize the scratch and damage masks to match the size of the original image
+    scratch_mask_resized = cv2.resize(scratch_masks_arr[0].astype(np.uint8), (img.shape[1], img.shape[0]))
+    damage_mask_resized = cv2.resize(damage_masks_arr[0].astype(np.uint8), (img.shape[1], img.shape[0]))
+    # Merge the scratch and damage masks into a single binary mask
+    merged_mask = np.zeros_like(scratch_mask_resized)
+    merged_mask[(scratch_mask_resized > 0) | (damage_mask_resized > 0)] = 255
+    # Overlay the merged mask on top of the original image
+    overlay = img.copy()
+    overlay[merged_mask == 255] = (0, 255, 0) # green color for the merged mask
+    overlay[damage_mask_resized == 255] = damage_color # yellow color for the damage mask
+    #output = cv2.addWeighted(overlay, 0.5, img, 0.5, 0)
+
+    # Merge the instance predictions from both predictors
+    image_np = np.array(image)
+    height, width, channels = image_np.shape
+
+    # Get the predicted boxes from the scratches predictor
+    pred_boxes_scratch = outputs_scratch["instances"].pred_boxes.tensor
+
+    # Get the predicted boxes from the damage predictor
+    pred_boxes_damage = outputs_damage["instances"].pred_boxes.tensor
+
+    # Concatenate the predicted boxes along the batch dimension
+    merged_boxes = torch.cat([pred_boxes_scratch, pred_boxes_damage], dim=0)
+
+    # Create a new Instances object with the merged boxes
+    merged_instances = Instances((image_np.shape[0], image_np.shape[1]))
+    merged_instances.pred_boxes = Boxes(merged_boxes)
+
+
+    # Visualize the Masks
     v_d = Visualizer(img[:, :, ::-1],
-                     metadata=metadata_damage,
-                     scale=0.5,
+                     metadata=metadata_damage,
+                     scale=0.5,
                      instance_mode=ColorMode.SEGMENTATION # remove the colors of unsegmented pixels. This option is only available for segmentation models
     )
-    #v_d = Visualizer(img,scale=1.2)
-    #print(outputs["instances"].to('cpu'))
+    v_d = Visualizer(img,scale=1.2)
     out_d = v_d.draw_instance_predictions(new_inst)
     img1 = out_d.get_image()[:, :, ::-1]
 
     v_s = Visualizer(img[:, :, ::-1],
-                     metadata=metadata_scratch,
-                     scale=0.5,
+                     metadata=metadata_scratch,
+                     scale=0.5,
                      instance_mode=ColorMode.SEGMENTATION # remove the colors of unsegmented pixels. This option is only available for segmentation models
     )
-    #v_s = Visualizer(img,scale=1.2)
+    v_s = Visualizer(img,scale=1.2)
    out_s = v_s.draw_instance_predictions(outputs_scratch["instances"])
     img2 = out_s.get_image()[:, :, ::-1]
 
     v_p = Visualizer(img[:, :, ::-1],
-                     metadata=metadata_parts,
-                     scale=0.5,
+                     metadata=metadata_parts,
+                     scale=0.5,
                      instance_mode=ColorMode.SEGMENTATION # remove the colors of unsegmented pixels. This option is only available for segmentation models
     )
-    #v_p = Visualizer(img,scale=1.2)
+    v_p = Visualizer(img,scale=1.2)
     out_p = v_p.draw_instance_predictions(outputs_parts["instances"])
     img3 = out_p.get_image()[:, :, ::-1]
-
-    return img1, img2, img3, parts_list_damages
+
+    # Visualize the overlay
+    v_m = Visualizer(overlay[:, :, ::-1],
+                     metadata=metadata_damage,
+                     scale=1.2,
+                     instance_mode=ColorMode.SEGMENTATION # display the overlay in black and white
+    )
+    # Draw the overlay with instance predictions
+    overlay_with_predictions = v_m.draw_instance_predictions(merged_instances.to("cpu")).get_image()[:, :, ::-1]
+    #v_m = Visualizer(overlay,scale=1.2)
+    out = v_m.draw_instance_predictions(merged_instances)
+    output = out.get_image()[:, :, ::-1]
+
+
+
+
+    return img1, img2, img3, parts_list_damages, output
 
 
 with gr.Blocks() as demo:
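In the code added above, merged_instances carries only pred_boxes, so draw_instance_predictions renders unlabeled boxes without masks or scores (and overlay_with_predictions is computed but never returned). A sketch of merging two prediction sets with masks and scores included, assuming both predictors ran on the same image (the helper name merge_predictions is hypothetical):

    import torch
    from detectron2.structures import Boxes, Instances

    def merge_predictions(inst_a, inst_b, image_size):
        # Concatenate per-instance fields so the Visualizer can draw boxes and masks together.
        merged = Instances(image_size)
        merged.pred_boxes = Boxes(torch.cat([inst_a.pred_boxes.tensor,
                                             inst_b.pred_boxes.tensor], dim=0))
        merged.pred_masks = torch.cat([inst_a.pred_masks, inst_b.pred_masks], dim=0)
        merged.scores = torch.cat([inst_a.scores, inst_b.scores], dim=0)
        return merged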
@@ -201,13 +256,15 @@ with gr.Blocks() as demo:
             im3 = gr.Image(type='numpy',label='Image of car parts')
         with gr.Tab('Information about damaged parts'):
             intersections = gr.Textbox(label='Information about type of damages on each part')
-
+        with gr.Tab('Image of overlayed damage parts'):
+            overlayed = gr.Image(type='numpy',label='Image of overlayed damage parts')
+
     #actions
     submit_button.click(
         fn=inference,
         inputs = [image],
-        outputs = [im1,im2,im3,intersections]
+        outputs = [im1,im2,im3,intersections, overlayed]
     )
-
+
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
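The final hunk wires the new overlay image into the Gradio interface: inference now returns five values, and outputs lists five components in matching order. A minimal self-contained sketch of that Blocks pattern (component names and the stub function are illustrative):

    import gradio as gr

    def stub_inference(image):
        # Stand-in for inference(): return one value per component in `outputs`, in order.
        return image, 'no damage found'

    with gr.Blocks() as demo:
        image = gr.Image(type='pil', label='Input image')
        submit_button = gr.Button('Submit')
        with gr.Tab('Result'):
            im1 = gr.Image(type='numpy', label='Annotated image')
        with gr.Tab('Report'):
            report = gr.Textbox(label='Damage report')
        submit_button.click(fn=stub_inference, inputs=[image], outputs=[im1, report])

    if __name__ == '__main__':
        demo.launch()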
 