Bounding box detection
PyTorch
Ontocord.AI commited on
Commit
6dd6cd6
1 Parent(s): bd2b542

Create visualizing_image.py

Browse files
Files changed (1) hide show
  1. visualizing_image.py +496 -0
visualizing_image.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ coding=utf-8
3
+ Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
4
+ Adapted From Facebook Inc, Detectron2
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.import copy
14
+ """
15
+ import colorsys
16
+ import io
17
+
18
+ import matplotlib as mpl
19
+ import matplotlib.colors as mplc
20
+ import matplotlib.figure as mplfigure
21
+ import numpy as np
22
+ import torch
23
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
24
+
25
+ import cv2
26
+ from .utils import img_tensorize
27
+
28
+
29
+ _SMALL_OBJ = 1000
30
+
31
+
32
+ class SingleImageViz:
33
+ def __init__(
34
+ self,
35
+ img,
36
+ scale=1.2,
37
+ edgecolor="g",
38
+ alpha=0.5,
39
+ linestyle="-",
40
+ saveas="test_out.jpg",
41
+ rgb=True,
42
+ pynb=False,
43
+ id2obj=None,
44
+ id2attr=None,
45
+ pad=0.7,
46
+ ):
47
+ """
48
+ img: an RGB image of shape (H, W, 3).
49
+ """
50
+ if isinstance(img, torch.Tensor):
51
+ img = img.numpy().astype("np.uint8")
52
+ if isinstance(img, str):
53
+ img = img_tensorize(img)
54
+ assert isinstance(img, np.ndarray)
55
+
56
+ width, height = img.shape[1], img.shape[0]
57
+ fig = mplfigure.Figure(frameon=False)
58
+ dpi = fig.get_dpi()
59
+ width_in = (width * scale + 1e-2) / dpi
60
+ height_in = (height * scale + 1e-2) / dpi
61
+ fig.set_size_inches(width_in, height_in)
62
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
63
+ ax.axis("off")
64
+ ax.set_xlim(0.0, width)
65
+ ax.set_ylim(height)
66
+
67
+ self.saveas = saveas
68
+ self.rgb = rgb
69
+ self.pynb = pynb
70
+ self.img = img
71
+ self.edgecolor = edgecolor
72
+ self.alpha = 0.5
73
+ self.linestyle = linestyle
74
+ self.font_size = int(np.sqrt(min(height, width)) * scale // 3)
75
+ self.width = width
76
+ self.height = height
77
+ self.scale = scale
78
+ self.fig = fig
79
+ self.ax = ax
80
+ self.pad = pad
81
+ self.id2obj = id2obj
82
+ self.id2attr = id2attr
83
+ self.canvas = FigureCanvasAgg(fig)
84
+
85
+ def add_box(self, box, color=None):
86
+ if color is None:
87
+ color = self.edgecolor
88
+ (x0, y0, x1, y1) = box
89
+ width = x1 - x0
90
+ height = y1 - y0
91
+ self.ax.add_patch(
92
+ mpl.patches.Rectangle(
93
+ (x0, y0),
94
+ width,
95
+ height,
96
+ fill=False,
97
+ edgecolor=color,
98
+ linewidth=self.font_size // 3,
99
+ alpha=self.alpha,
100
+ linestyle=self.linestyle,
101
+ )
102
+ )
103
+
104
+ def draw_boxes(self, boxes, obj_ids=None, obj_scores=None, attr_ids=None, attr_scores=None):
105
+ if len(boxes.shape) > 2:
106
+ boxes = boxes[0]
107
+ if len(obj_ids.shape) > 1:
108
+ obj_ids = obj_ids[0]
109
+ if len(obj_scores.shape) > 1:
110
+ obj_scores = obj_scores[0]
111
+ if len(attr_ids.shape) > 1:
112
+ attr_ids = attr_ids[0]
113
+ if len(attr_scores.shape) > 1:
114
+ attr_scores = attr_scores[0]
115
+ if isinstance(boxes, torch.Tensor):
116
+ boxes = boxes.numpy()
117
+ if isinstance(boxes, list):
118
+ boxes = np.array(boxes)
119
+ assert isinstance(boxes, np.ndarray)
120
+ areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
121
+ sorted_idxs = np.argsort(-areas).tolist()
122
+ boxes = boxes[sorted_idxs] if boxes is not None else None
123
+ obj_ids = obj_ids[sorted_idxs] if obj_ids is not None else None
124
+ obj_scores = obj_scores[sorted_idxs] if obj_scores is not None else None
125
+ attr_ids = attr_ids[sorted_idxs] if attr_ids is not None else None
126
+ attr_scores = attr_scores[sorted_idxs] if attr_scores is not None else None
127
+
128
+ assigned_colors = [self._random_color(maximum=1) for _ in range(len(boxes))]
129
+ assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
130
+ if obj_ids is not None:
131
+ labels = self._create_text_labels_attr(obj_ids, obj_scores, attr_ids, attr_scores)
132
+ for i in range(len(boxes)):
133
+ color = assigned_colors[i]
134
+ self.add_box(boxes[i], color)
135
+ self.draw_labels(labels[i], boxes[i], color)
136
+
137
+ def draw_labels(self, label, box, color):
138
+ x0, y0, x1, y1 = box
139
+ text_pos = (x0, y0)
140
+ instance_area = (y1 - y0) * (x1 - x0)
141
+ small = _SMALL_OBJ * self.scale
142
+ if instance_area < small or y1 - y0 < 40 * self.scale:
143
+ if y1 >= self.height - 5:
144
+ text_pos = (x1, y0)
145
+ else:
146
+ text_pos = (x0, y1)
147
+
148
+ height_ratio = (y1 - y0) / np.sqrt(self.height * self.width)
149
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
150
+ font_size = np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
151
+ font_size *= 0.75 * self.font_size
152
+
153
+ self.draw_text(
154
+ text=label,
155
+ position=text_pos,
156
+ color=lighter_color,
157
+ )
158
+
159
+ def draw_text(
160
+ self,
161
+ text,
162
+ position,
163
+ color="g",
164
+ ha="left",
165
+ ):
166
+ rotation = 0
167
+ font_size = self.font_size
168
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
169
+ color[np.argmax(color)] = max(0.8, np.max(color))
170
+ bbox = {
171
+ "facecolor": "black",
172
+ "alpha": self.alpha,
173
+ "pad": self.pad,
174
+ "edgecolor": "none",
175
+ }
176
+ x, y = position
177
+ self.ax.text(
178
+ x,
179
+ y,
180
+ text,
181
+ size=font_size * self.scale,
182
+ family="sans-serif",
183
+ bbox=bbox,
184
+ verticalalignment="top",
185
+ horizontalalignment=ha,
186
+ color=color,
187
+ zorder=10,
188
+ rotation=rotation,
189
+ )
190
+
191
+ def save(self, saveas=None):
192
+ if saveas is None:
193
+ saveas = self.saveas
194
+ if saveas.lower().endswith(".jpg") or saveas.lower().endswith(".png"):
195
+ cv2.imwrite(
196
+ saveas,
197
+ self._get_buffer()[:, :, ::-1],
198
+ )
199
+ else:
200
+ self.fig.savefig(saveas)
201
+
202
+ def _create_text_labels_attr(self, classes, scores, attr_classes, attr_scores):
203
+ labels = [self.id2obj[i] for i in classes]
204
+ attr_labels = [self.id2attr[i] for i in attr_classes]
205
+ labels = [
206
+ f"{label} {score:.2f} {attr} {attr_score:.2f}"
207
+ for label, score, attr, attr_score in zip(labels, scores, attr_labels, attr_scores)
208
+ ]
209
+ return labels
210
+
211
+ def _create_text_labels(self, classes, scores):
212
+ labels = [self.id2obj[i] for i in classes]
213
+ if scores is not None:
214
+ if labels is None:
215
+ labels = ["{:.0f}%".format(s * 100) for s in scores]
216
+ else:
217
+ labels = ["{} {:.0f}%".format(li, s * 100) for li, s in zip(labels, scores)]
218
+ return labels
219
+
220
+ def _random_color(self, maximum=255):
221
+ idx = np.random.randint(0, len(_COLORS))
222
+ ret = _COLORS[idx] * maximum
223
+ if not self.rgb:
224
+ ret = ret[::-1]
225
+ return ret
226
+
227
+ def _get_buffer(self):
228
+ if not self.pynb:
229
+ s, (width, height) = self.canvas.print_to_buffer()
230
+ if (width, height) != (self.width, self.height):
231
+ img = cv2.resize(self.img, (width, height))
232
+ else:
233
+ img = self.img
234
+ else:
235
+ buf = io.BytesIO() # works for cairo backend
236
+ self.canvas.print_rgba(buf)
237
+ width, height = self.width, self.height
238
+ s = buf.getvalue()
239
+ img = self.img
240
+
241
+ buffer = np.frombuffer(s, dtype="uint8")
242
+ img_rgba = buffer.reshape(height, width, 4)
243
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
244
+
245
+ try:
246
+ import numexpr as ne # fuse them with numexpr
247
+
248
+ visualized_image = ne.evaluate("img * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
249
+ except ImportError:
250
+ alpha = alpha.astype("float32") / 255.0
251
+ visualized_image = img * (1 - alpha) + rgb * alpha
252
+
253
+ return visualized_image.astype("uint8")
254
+
255
+ def _change_color_brightness(self, color, brightness_factor):
256
+ assert brightness_factor >= -1.0 and brightness_factor <= 1.0
257
+ color = mplc.to_rgb(color)
258
+ polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
259
+ modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
260
+ modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
261
+ modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
262
+ modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
263
+ return modified_color
264
+
265
+
266
+ # Color map
267
+ _COLORS = (
268
+ np.array(
269
+ [
270
+ 0.000,
271
+ 0.447,
272
+ 0.741,
273
+ 0.850,
274
+ 0.325,
275
+ 0.098,
276
+ 0.929,
277
+ 0.694,
278
+ 0.125,
279
+ 0.494,
280
+ 0.184,
281
+ 0.556,
282
+ 0.466,
283
+ 0.674,
284
+ 0.188,
285
+ 0.301,
286
+ 0.745,
287
+ 0.933,
288
+ 0.635,
289
+ 0.078,
290
+ 0.184,
291
+ 0.300,
292
+ 0.300,
293
+ 0.300,
294
+ 0.600,
295
+ 0.600,
296
+ 0.600,
297
+ 1.000,
298
+ 0.000,
299
+ 0.000,
300
+ 1.000,
301
+ 0.500,
302
+ 0.000,
303
+ 0.749,
304
+ 0.749,
305
+ 0.000,
306
+ 0.000,
307
+ 1.000,
308
+ 0.000,
309
+ 0.000,
310
+ 0.000,
311
+ 1.000,
312
+ 0.667,
313
+ 0.000,
314
+ 1.000,
315
+ 0.333,
316
+ 0.333,
317
+ 0.000,
318
+ 0.333,
319
+ 0.667,
320
+ 0.000,
321
+ 0.333,
322
+ 1.000,
323
+ 0.000,
324
+ 0.667,
325
+ 0.333,
326
+ 0.000,
327
+ 0.667,
328
+ 0.667,
329
+ 0.000,
330
+ 0.667,
331
+ 1.000,
332
+ 0.000,
333
+ 1.000,
334
+ 0.333,
335
+ 0.000,
336
+ 1.000,
337
+ 0.667,
338
+ 0.000,
339
+ 1.000,
340
+ 1.000,
341
+ 0.000,
342
+ 0.000,
343
+ 0.333,
344
+ 0.500,
345
+ 0.000,
346
+ 0.667,
347
+ 0.500,
348
+ 0.000,
349
+ 1.000,
350
+ 0.500,
351
+ 0.333,
352
+ 0.000,
353
+ 0.500,
354
+ 0.333,
355
+ 0.333,
356
+ 0.500,
357
+ 0.333,
358
+ 0.667,
359
+ 0.500,
360
+ 0.333,
361
+ 1.000,
362
+ 0.500,
363
+ 0.667,
364
+ 0.000,
365
+ 0.500,
366
+ 0.667,
367
+ 0.333,
368
+ 0.500,
369
+ 0.667,
370
+ 0.667,
371
+ 0.500,
372
+ 0.667,
373
+ 1.000,
374
+ 0.500,
375
+ 1.000,
376
+ 0.000,
377
+ 0.500,
378
+ 1.000,
379
+ 0.333,
380
+ 0.500,
381
+ 1.000,
382
+ 0.667,
383
+ 0.500,
384
+ 1.000,
385
+ 1.000,
386
+ 0.500,
387
+ 0.000,
388
+ 0.333,
389
+ 1.000,
390
+ 0.000,
391
+ 0.667,
392
+ 1.000,
393
+ 0.000,
394
+ 1.000,
395
+ 1.000,
396
+ 0.333,
397
+ 0.000,
398
+ 1.000,
399
+ 0.333,
400
+ 0.333,
401
+ 1.000,
402
+ 0.333,
403
+ 0.667,
404
+ 1.000,
405
+ 0.333,
406
+ 1.000,
407
+ 1.000,
408
+ 0.667,
409
+ 0.000,
410
+ 1.000,
411
+ 0.667,
412
+ 0.333,
413
+ 1.000,
414
+ 0.667,
415
+ 0.667,
416
+ 1.000,
417
+ 0.667,
418
+ 1.000,
419
+ 1.000,
420
+ 1.000,
421
+ 0.000,
422
+ 1.000,
423
+ 1.000,
424
+ 0.333,
425
+ 1.000,
426
+ 1.000,
427
+ 0.667,
428
+ 1.000,
429
+ 0.333,
430
+ 0.000,
431
+ 0.000,
432
+ 0.500,
433
+ 0.000,
434
+ 0.000,
435
+ 0.667,
436
+ 0.000,
437
+ 0.000,
438
+ 0.833,
439
+ 0.000,
440
+ 0.000,
441
+ 1.000,
442
+ 0.000,
443
+ 0.000,
444
+ 0.000,
445
+ 0.167,
446
+ 0.000,
447
+ 0.000,
448
+ 0.333,
449
+ 0.000,
450
+ 0.000,
451
+ 0.500,
452
+ 0.000,
453
+ 0.000,
454
+ 0.667,
455
+ 0.000,
456
+ 0.000,
457
+ 0.833,
458
+ 0.000,
459
+ 0.000,
460
+ 1.000,
461
+ 0.000,
462
+ 0.000,
463
+ 0.000,
464
+ 0.167,
465
+ 0.000,
466
+ 0.000,
467
+ 0.333,
468
+ 0.000,
469
+ 0.000,
470
+ 0.500,
471
+ 0.000,
472
+ 0.000,
473
+ 0.667,
474
+ 0.000,
475
+ 0.000,
476
+ 0.833,
477
+ 0.000,
478
+ 0.000,
479
+ 1.000,
480
+ 0.000,
481
+ 0.000,
482
+ 0.000,
483
+ 0.143,
484
+ 0.143,
485
+ 0.143,
486
+ 0.857,
487
+ 0.857,
488
+ 0.857,
489
+ 1.000,
490
+ 1.000,
491
+ 1.000,
492
+ ]
493
+ )
494
+ .astype(np.float32)
495
+ .reshape(-1, 3)
496
+ )