intelliarts committed
Commit 7f798e2
1 Parent(s): 957c8b6

Create app.py

Files changed (1)
  app.py  +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
+ # Install detectron2 from source if it is not already available in the environment
+ try:
+     import detectron2
+ except ImportError:
+     import os
+     os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
+
+ import gradio as gr
+ import numpy as np
+
+ import torch
+ import detectron2
+ from detectron2 import model_zoo
+ from detectron2.engine import DefaultPredictor
+ from detectron2.config import get_cfg
+ from detectron2.utils.visualizer import Visualizer
+ from detectron2.data import MetadataCatalog
+ from detectron2.utils.visualizer import ColorMode
+
+ model_path = 'model_final.pth'
+
+ # Configure a COCO Mask R-CNN and load the fine-tuned damage-detection weights
+ cfg = get_cfg()
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.6
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
+ cfg.MODEL.WEIGHTS = model_path
+
+ if not torch.cuda.is_available():
+     cfg.MODEL.DEVICE = 'cpu'
+
+ predictor = DefaultPredictor(cfg)
+ my_metadata = MetadataCatalog.get("car_dataset_val")
+ my_metadata.thing_classes = ["damage"]
+
+ # Union predicted masks that overlap and drop the absorbed duplicates
+ def merge_segment(pred_segm):
+     merge_dict = {}
+     for i in range(len(pred_segm)):
+         merge_dict[i] = []
+         for j in range(i + 1, len(pred_segm)):
+             if torch.sum(pred_segm[i] * pred_segm[j]) > 0:
+                 merge_dict[i].append(j)
+
+     to_delete = []
+     for key in merge_dict:
+         for element in merge_dict[key]:
+             to_delete.append(element)
+
+     for element in to_delete:
+         merge_dict.pop(element, None)
+
+     empty_delete = []
+     for key in merge_dict:
+         if merge_dict[key] == []:
+             empty_delete.append(key)
+
+     for element in empty_delete:
+         merge_dict.pop(element, None)
+
+     for key in merge_dict:
+         for element in merge_dict[key]:
+             pred_segm[key] += pred_segm[element]
+
+     except_elem = list(set(to_delete))
+
+     new_indexes = list(range(len(pred_segm)))
+     for elem in except_elem:
+         new_indexes.remove(elem)
+
+     return pred_segm[new_indexes]
+
+ # Run the damage-segmentation model on an input image and return the rendered masks
+ def inference(image):
+     print(image.height)
+
+     height = image.height
+
+     # img = np.array(image.resize((500, height)))
+     img = np.array(image)
+     outputs = predictor(img)
+     out_dict = outputs["instances"].to("cpu").get_fields()
+     new_inst = detectron2.structures.Instances((1024, 1024))
+     new_inst.set('pred_masks', merge_segment(out_dict['pred_masks']))
+     v = Visualizer(img[:, :, ::-1],
+                    metadata=my_metadata,
+                    scale=0.5,
+                    instance_mode=ColorMode.SEGMENTATION  # remove the colors of unsegmented pixels; only available for segmentation models
+                    )
+     # v = Visualizer(img, scale=1.2)
+     # print(outputs["instances"].to('cpu'))
+     out = v.draw_instance_predictions(new_inst)
+
+     return out.get_image()[:, :, ::-1]
+
+ title = "Detectron2 Car Damage Detection"
+ description = "This demo introduces an interactive playground for our trained Detectron2 model."
+
+ gr.Interface(
+     inference,
+     [gr.inputs.Image(type="pil", label="Input")],
+     gr.outputs.Image(type="numpy", label="Output"),
+     title=title,
+     description=description,
+     examples=[]).launch()