ausawin commited on
Commit
749d98f
1 Parent(s): 58aab4a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import io
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import requests, validators
7
+ from sqlalchemy import true
8
+ import torch
9
+ import pathlib
10
+ from PIL import Image
11
+ import os
12
+
13
+ from detecto import core, utils, visualize
14
+ from detecto.core import Model as DetectoModel
15
+
16
+
17
+
18
+ title = """<h1 id="title">AEYE INSPECTOR</h1>"""
19
+ css = '''
20
+ h1#title {
21
+ text-align: center;
22
+ }
23
+ '''
24
+ COLORS = [
25
+ [0.000, 0.447, 0.741],
26
+ [0.850, 0.325, 0.098],
27
+ [0.929, 0.694, 0.125],
28
+ [0.494, 0.184, 0.556],
29
+ [0.466, 0.674, 0.188],
30
+ [0.301, 0.745, 0.933]
31
+ ]
32
+
33
+ models = ["Detecto (Faster-RCNN)","YOLOv100"]
34
+ urls = [#'http://fbbbb.ddns.net:4080/static/images/ai_img(1).jpg',
35
+ #'http://fbbbb.ddns.net:4080/static/images/ai_img(2).jpg',
36
+ #'http://fbbbb.ddns.net:4080/static/images/ai_img(3).jpg',
37
+ #'http://fbbbb.ddns.net:4080/static/images/ai_img(4).jpg',
38
+ #'http://fbbbb.ddns.net:4080/static/images/ai_img(5).jpg',
39
+ #'http://fbbbb.ddns.net:4080/static/images/ai_img(6).jpg',
40
+ #'http://fbbbb.ddns.net:4080/static/images/ai_img(7).jpg',
41
+ #'http://fbbbb.ddns.net:4080/static/images/ai_img(8).jpg',
42
+ #'http://fbbbb.ddns.net:4080/static/images/ai_img(9).jpg',
43
+ #'http://fbbbb.ddns.net:4080/static/images/ai_img(10).jpg'
44
+ ]
45
+
46
+
47
+ def detect_objects(model_name,url_input,image_input,threshold):
48
+ if validators.url(url_input):
49
+ image = Image.open(requests.get(url_input, stream=True).raw)
50
+ elif image_input:
51
+ image = image_input
52
+
53
+ if 'Detecto' in model_name:
54
+ model = DetectoModel(['heltmet_safe','face_mask','safety_vest','safety_belts','safety_shoes'])
55
+ model.load('ai_30ep.pth',['heltmet_safe','face_mask','safety_vest','safety_belts','safety_shoes'])
56
+
57
+ print("OK")
58
+ labels, boxes, scores = model.predict(image)
59
+ viz_img = visualize_prediction(image, labels, boxes, scores, threshold)
60
+
61
+ print(labels)
62
+ #print(boxes)
63
+ print(scores)
64
+
65
+ return viz_img
66
+
67
+ def visualize_prediction(pil_img, labels, boxes, scores, threshold=0.7):
68
+ keeps = scores > threshold
69
+ print(keeps)
70
+
71
+ boxess = boxes[keeps].tolist()
72
+ print(boxess)
73
+
74
+ #labelss = labels[keep]
75
+ #print(labelss)
76
+
77
+ plt.figure(figsize=(16, 10))
78
+ plt.imshow(pil_img)
79
+ ax = plt.gca()
80
+ colors = COLORS * 100
81
+ for idx, keep in enumerate(keeps):
82
+ if keep:
83
+ (xmin, ymin, xmax, ymax) = zip(boxess[idx])
84
+ print(xmin[0])
85
+ print(ymin[0])
86
+ ax.add_patch(plt.Rectangle((xmin[0], ymin[0]), xmax[0] - xmin[0], ymax[0] - ymin[0], fill=False, color=colors[idx], linewidth=3))
87
+ ax.text(xmin[0], ymin[0], f'{labels[idx]}: {scores[idx]:0.2f}', fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
88
+
89
+ plt.axis("off")
90
+ return fig2img(plt.gcf())
91
+
92
+ def set_example_image(example: list) -> dict:
93
+ return gr.Image.update(value=example[0])
94
+
95
+ def set_example_url(example: list) -> dict:
96
+ return gr.Textbox.update(value=example[0])
97
+
98
+ def fig2img(fig):
99
+ buf = io.BytesIO()
100
+ fig.savefig(buf)
101
+ buf.seek(0)
102
+ img = Image.open(buf)
103
+ return img
104
+
105
+
106
+
107
+ app = gr.Blocks(css=css)
108
+
109
+ with app:
110
+ gr.Markdown(title)
111
+ options = gr.Dropdown(choices=models,label='Select Object Detection Model',show_label=True)
112
+ slider_input = gr.Slider(minimum=0.2,maximum=1,value=0.7,label='Prediction Threshold')
113
+
114
+ with gr.Tabs():
115
+ with gr.TabItem('Image URL'):
116
+ with gr.Row():
117
+ url_input = gr.Textbox(lines=2,label='Enter valid image URL here..')
118
+ img_output_from_url = gr.Image(shape=(650,650))
119
+
120
+ with gr.Row():
121
+ example_url = gr.Dataset(components=[url_input],samples=[[str(url)] for url in urls])
122
+
123
+ url_but = gr.Button('Detect')
124
+
125
+
126
+ with gr.TabItem('Image Upload'):
127
+ with gr.Row():
128
+ img_input = gr.Image(type='pil')
129
+ img_output_from_upload= gr.Image(shape=(650,650))
130
+
131
+ with gr.Row():
132
+ example_images = gr.Dataset(components=[img_input],
133
+ samples=[[path.as_posix()]
134
+ for path in sorted(pathlib.Path('images').rglob('*.JPG'))])
135
+
136
+ img_but = gr.Button('Detect')
137
+
138
+
139
+ url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_url,queue=True)
140
+ img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_upload,queue=True)
141
+ example_images.click(fn=set_example_image,inputs=[example_images],outputs=[img_input])
142
+ example_url.click(fn=set_example_url,inputs=[example_url],outputs=[url_input])
143
+
144
+
145
+ app.launch(enable_queue=True, server_name='0.0.0.0', show_error=True)