Thanaphit committed on
Commit
3296110
·
1 Parent(s): a5e6cf9

Car Parts Segmentation App

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ requirements.in
2
+ __pycache__
3
+ weight/weight.pt
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import gdown
import gradio as gr
from utils import Predictor

# Google Drive id of the trained YOLO segmentation weights and the local
# path they are cached at. NOTE(review): .gitignore lists "weight/weight.pt"
# (singular) while we write "weight/weights.pt" — confirm which is intended.
PT_URL = "https://drive.google.com/uc?id=1I_LKds9obElNIZcW_DM8zyknrwRmrASj"
PT_OUT = "weight/weights.pt"

# One-image example rows for the Gradio UIs, collected from sample/.
SAMPLE = [
    [f"{root}/{file}"]
    for root, _, files in os.walk("sample", topdown=False)
    for file in files
]

# Download the weights on first launch only. gdown requires the target
# directory to already exist, so create it before downloading.
if not os.path.exists(PT_OUT):
    os.makedirs(os.path.dirname(PT_OUT), exist_ok=True)
    gdown.download(PT_URL, PT_OUT, quiet=True)

predictor = Predictor(PT_OUT)

# Tab 1: draw labelled bounding boxes on the input image.
box_ui = gr.Interface(
    fn=predictor.annotate_boxes,
    inputs=[gr.components.Image(type="filepath", label="Input Image")],
    outputs=[gr.components.Image(type="numpy", label="Output Image")],
    title="Car parts segmentation",
    examples=SAMPLE,
    cache_examples=False,
)

# Tab 2: overlay translucent segmentation masks with contours.
mask_ui = gr.Interface(
    fn=predictor.annotate_masks,
    inputs=[gr.components.Image(type="filepath", label="Input Image")],
    outputs=[gr.components.Image(type="numpy", label="Output Image")],
    title="Car parts segmentation",
    examples=SAMPLE,
    cache_examples=False,
)

gr.TabbedInterface(
    [box_ui, mask_ui],
    tab_names=["Bounding Boxes inference", "Masks inference"],
).queue().launch()
requirements.txt ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is autogenerated by pip-compile with Python 3.9
3
+ # by the following command:
4
+ #
5
+ # pip-compile requirements.in
6
+ #
7
+ aiofiles==23.2.0
8
+ # via gradio
9
+ aiohttp==3.8.5
10
+ # via gradio
11
+ aiosignal==1.3.1
12
+ # via aiohttp
13
+ altair==5.0.1
14
+ # via gradio
15
+ annotated-types==0.5.0
16
+ # via pydantic
17
+ anyio==3.7.1
18
+ # via
19
+ # httpcore
20
+ # starlette
21
+ async-timeout==4.0.2
22
+ # via aiohttp
23
+ attrs==23.1.0
24
+ # via
25
+ # aiohttp
26
+ # jsonschema
27
+ # referencing
28
+ beautifulsoup4==4.12.2
29
+ # via gdown
30
+ certifi==2023.7.22
31
+ # via
32
+ # httpcore
33
+ # httpx
34
+ # requests
35
+ charset-normalizer==3.2.0
36
+ # via
37
+ # aiohttp
38
+ # requests
39
+ click==8.1.6
40
+ # via uvicorn
41
+ contourpy==1.1.0
42
+ # via matplotlib
43
+ cycler==0.11.0
44
+ # via matplotlib
45
+ exceptiongroup==1.1.2
46
+ # via anyio
47
+ fastapi==0.101.0
48
+ # via gradio
49
+ ffmpy==0.3.1
50
+ # via gradio
51
+ filelock==3.12.2
52
+ # via
53
+ # gdown
54
+ # huggingface-hub
55
+ # torch
56
+ fonttools==4.42.0
57
+ # via matplotlib
58
+ frozenlist==1.4.0
59
+ # via
60
+ # aiohttp
61
+ # aiosignal
62
+ fsspec==2023.6.0
63
+ # via
64
+ # gradio-client
65
+ # huggingface-hub
66
+ gdown==4.7.1
67
+ # via -r requirements.in
68
+ gradio==3.39.0
69
+ # via -r requirements.in
70
+ gradio-client==0.3.0
71
+ # via gradio
72
+ h11==0.14.0
73
+ # via
74
+ # httpcore
75
+ # uvicorn
76
+ httpcore==0.17.3
77
+ # via httpx
78
+ httpx==0.24.1
79
+ # via
80
+ # gradio
81
+ # gradio-client
82
+ huggingface-hub==0.16.4
83
+ # via
84
+ # gradio
85
+ # gradio-client
86
+ idna==3.4
87
+ # via
88
+ # anyio
89
+ # httpx
90
+ # requests
91
+ # yarl
92
+ importlib-resources==6.0.1
93
+ # via matplotlib
94
+ jinja2==3.1.2
95
+ # via
96
+ # altair
97
+ # gradio
98
+ # torch
99
+ jsonschema==4.19.0
100
+ # via altair
101
+ jsonschema-specifications==2023.7.1
102
+ # via jsonschema
103
+ kiwisolver==1.4.4
104
+ # via matplotlib
105
+ linkify-it-py==2.0.2
106
+ # via markdown-it-py
107
+ markdown-it-py[linkify]==2.2.0
108
+ # via
109
+ # gradio
110
+ # mdit-py-plugins
111
+ markupsafe==2.1.3
112
+ # via
113
+ # gradio
114
+ # jinja2
115
+ matplotlib==3.7.2
116
+ # via
117
+ # gradio
118
+ # seaborn
119
+ # ultralytics
120
+ mdit-py-plugins==0.3.3
121
+ # via gradio
122
+ mdurl==0.1.2
123
+ # via markdown-it-py
124
+ mpmath==1.3.0
125
+ # via sympy
126
+ multidict==6.0.4
127
+ # via
128
+ # aiohttp
129
+ # yarl
130
+ networkx==3.1
131
+ # via torch
132
+ numpy==1.25.2
133
+ # via
134
+ # -r requirements.in
135
+ # altair
136
+ # contourpy
137
+ # gradio
138
+ # matplotlib
139
+ # opencv-python
140
+ # pandas
141
+ # scipy
142
+ # seaborn
143
+ # torchvision
144
+ # ultralytics
145
+ opencv-python==4.8.0.74
146
+ # via ultralytics
147
+ orjson==3.9.4
148
+ # via gradio
149
+ packaging==23.1
150
+ # via
151
+ # gradio
152
+ # gradio-client
153
+ # huggingface-hub
154
+ # matplotlib
155
+ pandas==2.0.3
156
+ # via
157
+ # -r requirements.in
158
+ # altair
159
+ # gradio
160
+ # seaborn
161
+ # ultralytics
162
+ pillow==10.0.0
163
+ # via
164
+ # gradio
165
+ # matplotlib
166
+ # torchvision
167
+ # ultralytics
168
+ psutil==5.9.5
169
+ # via ultralytics
170
+ py-cpuinfo==9.0.0
171
+ # via ultralytics
172
+ pydantic==2.1.1
173
+ # via
174
+ # fastapi
175
+ # gradio
176
+ pydantic-core==2.4.0
177
+ # via pydantic
178
+ pydub==0.25.1
179
+ # via gradio
180
+ pyparsing==3.0.9
181
+ # via matplotlib
182
+ pysocks==1.7.1
183
+ # via requests
184
+ python-dateutil==2.8.2
185
+ # via
186
+ # matplotlib
187
+ # pandas
188
+ python-multipart==0.0.6
189
+ # via gradio
190
+ pytz==2023.3
191
+ # via pandas
192
+ pyyaml==6.0.1
193
+ # via
194
+ # gradio
195
+ # huggingface-hub
196
+ # ultralytics
197
+ referencing==0.30.2
198
+ # via
199
+ # jsonschema
200
+ # jsonschema-specifications
201
+ requests[socks]==2.31.0
202
+ # via
203
+ # gdown
204
+ # gradio
205
+ # gradio-client
206
+ # huggingface-hub
207
+ # torchvision
208
+ # ultralytics
209
+ rpds-py==0.9.2
210
+ # via
211
+ # jsonschema
212
+ # referencing
213
+ scipy==1.11.1
214
+ # via ultralytics
215
+ seaborn==0.12.2
216
+ # via ultralytics
217
+ semantic-version==2.10.0
218
+ # via gradio
219
+ six==1.16.0
220
+ # via
221
+ # gdown
222
+ # python-dateutil
223
+ sniffio==1.3.0
224
+ # via
225
+ # anyio
226
+ # httpcore
227
+ # httpx
228
+ soupsieve==2.4.1
229
+ # via beautifulsoup4
230
+ starlette==0.27.0
231
+ # via fastapi
232
+ sympy==1.12
233
+ # via torch
234
+ toolz==0.12.0
235
+ # via altair
236
+ torch==2.0.1
237
+ # via
238
+ # torchvision
239
+ # ultralytics
240
+ torchvision==0.15.2
241
+ # via ultralytics
242
+ tqdm==4.66.0
243
+ # via
244
+ # gdown
245
+ # huggingface-hub
246
+ # ultralytics
247
+ typing-extensions==4.7.1
248
+ # via
249
+ # altair
250
+ # fastapi
251
+ # gradio
252
+ # gradio-client
253
+ # huggingface-hub
254
+ # pydantic
255
+ # pydantic-core
256
+ # starlette
257
+ # torch
258
+ # uvicorn
259
+ tzdata==2023.3
260
+ # via pandas
261
+ uc-micro-py==1.0.2
262
+ # via linkify-it-py
263
+ ultralytics==8.0.150
264
+ # via -r requirements.in
265
+ urllib3==2.0.4
266
+ # via requests
267
+ uvicorn==0.23.2
268
+ # via gradio
269
+ websockets==11.0.3
270
+ # via
271
+ # gradio
272
+ # gradio-client
273
+ yarl==1.9.2
274
+ # via aiohttp
275
+ zipp==3.16.2
276
+ # via importlib-resources
sample/sample-1.jpg ADDED
sample/sample-2.jpg ADDED
sample/sample-3.jpg ADDED
sample/sample-4.jpg ADDED
sample/sample-5.jpg ADDED
sample/sample-6.jpg ADDED
utils/__init__.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from ultralytics import YOLO
from ultralytics.utils.ops import scale_image


class Predictor:
    """YOLO segmentation model wrapper that renders box or mask overlays.

    Both annotate_* methods take an image file path and return an
    RGB numpy image suitable for display in Gradio.
    """

    def __init__(self, model_weight):
        """Load the model and build a stable per-class BGR color table.

        model_weight: path to a YOLO .pt weights file.
        """
        self.model = YOLO(model_weight)
        self.category_map = self.model.names  # {class_id: label}
        self.NCLS = len(self.category_map)

        # Discretize the rainbow colormap into NCLS bins so each class id
        # maps to a distinct, deterministic color.
        cmap = plt.cm.rainbow
        cmaplist = [cmap(i) for i in range(cmap.N)]
        self.cmap = cmap.from_list("my cmap", cmaplist, cmap.N)

        bounds = np.linspace(0, self.NCLS, self.NCLS + 1)
        norm = colors.BoundaryNorm(bounds, self.cmap.N)

        # Matplotlib yields RGBA in [0, 1]; OpenCV draws BGR in [0, 255],
        # hence the channel swap and scaling.
        category_cmap = {k: cmap(norm(int(k))) for k in self.category_map}
        self.category_cmap = {
            k: (v[2] * 255, v[1] * 255, v[0] * 255)
            for k, v in category_cmap.items()
        }

    def predict(self, image_path):
        """Run inference on a single image.

        Returns (image_bgr, class_ids, confidences, xyxy_boxes, masks,
        results), all as numpy data from the first (only) result.
        """
        image = cv2.imread(image_path)
        outputs = self.model.predict(source=image_path)
        results = outputs[0].cpu().numpy()

        boxes = results.boxes.xyxy
        confs = results.boxes.conf
        cls = results.boxes.cls
        # probs = results.boxes.probs
        # results.masks is None when nothing is detected; substitute an
        # empty stack so downstream zip() loops simply draw nothing
        # instead of raising AttributeError.
        if results.masks is not None:
            masks = results.masks.data
        else:
            masks = np.zeros((0, *image.shape[:2]), dtype="uint8")

        return image, cls, confs, boxes, masks, results

    def annotate_boxes(self, image_path):
        """Draw labelled, confidence-annotated bounding boxes; return RGB."""
        image, cls, confs, boxes, _, results = self.predict(image_path)

        for box, cl, conf in zip(boxes, cls, confs):
            label = results.names[cl]
            color = self.category_cmap[cl]
            text = label + f" {conf:.2f}"
            x1, y1, x2, y2 = (int(p) for p in box)

            cv2.rectangle(image, (x1, y1), (x2, y2),
                color=color,
                thickness=2,
                lineType=cv2.LINE_AA
            )
            # Filled banner behind the label so white text stays readable.
            (w, h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_DUPLEX, 0.3, 1)
            cv2.rectangle(image, (x1, y1 - 2 * h), (x1 + w, y1), color, -1)
            cv2.putText(image, text, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_DUPLEX, 0.3, (255, 255, 255), 1)

        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def annotate_masks(self, image_path):
        """Overlay translucent class-colored masks with contours; return RGB."""
        image, cls, confs, _, masks, results = self.predict(image_path)
        ori_shape = image.shape[:2]

        for mask, cl, conf in zip(masks, cls, confs):
            mask = mask.astype("uint8")
            label = results.names[cl]
            color = self.category_cmap[cl]
            text = label + f" {conf:.2f}"

            # Translucent fill: color the mask, rescale it from the model's
            # output resolution back to the original image, then blend.
            _mask = np.where(mask[..., None], color, (0, 0, 0))
            _mask = scale_image(_mask, ori_shape).astype("uint8")
            image = cv2.addWeighted(image, 1, _mask, 0.5, 0)

            # Solid contour outline drawn on top of the fill.
            contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            boundary = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGBA).astype("float")
            cv2.drawContours(boundary, contours, -1, color, 2)
            boundary = scale_image(boundary, ori_shape)[:, :, :-1].astype("uint8")
            image = cv2.addWeighted(image, 1, boundary, 1, 0)

            # Center the label on the mask's mean pixel position, drawn
            # twice (dark then light) for a cheap outline effect.
            cy, cx = np.round(np.argwhere(_mask != [0, 0, 0]).mean(axis=0)[0:2]).astype(int)
            (w, h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_DUPLEX, 0.5, 1)

            cv2.putText(image, text, (cx - int(0.5 * w), cy),
                        cv2.FONT_HERSHEY_DUPLEX, 0.5, (0, 0, 0), 2)
            cv2.putText(image, text, (cx - int(0.5 * w), cy),
                        cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255), 1)

        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)