saritha5 commited on
Commit
6c958fb
1 Parent(s): 80640e2

Upload visualization_utils.py

Browse files
Files changed (1) hide show
  1. visualization_utils.py +1353 -0
visualization_utils.py ADDED
@@ -0,0 +1,1353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """A set of functions that are used for visualization.
17
+
18
+ These functions often receive an image, perform some visualization on the image.
19
+ The functions do not return a value, instead they modify the image itself.
20
+
21
+ """
22
+ from __future__ import absolute_import
23
+ from __future__ import division
24
+ from __future__ import print_function
25
+
26
+
27
+ import abc
28
+ import collections
29
+ # Set headless-friendly backend.
30
+ import matplotlib; matplotlib.use('Agg') # pylint: disable=multiple-statements
31
+ import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
32
+ import numpy as np
33
+ import PIL.Image as Image
34
+ import PIL.ImageColor as ImageColor
35
+ import PIL.ImageDraw as ImageDraw
36
+ import PIL.ImageFont as ImageFont
37
+ import six
38
+ from six.moves import range
39
+ from six.moves import zip
40
+ import tensorflow as tf
41
+
42
+ import keypoint_ops
43
+ import standard_fields as fields
44
+ import shape_utils
45
+
46
# Pixel offsets of a title from the left/top edge of a figure.
_TITLE_LEFT_MARGIN = 10
_TITLE_TOP_MARGIN = 10

# Named colors (resolvable by PIL.ImageColor) used to color boxes, masks and
# heatmap channels. The order matters: colors are selected by index.
STANDARD_COLORS = [
    'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
    'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
    'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
    'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
    'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
    'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
    'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
    'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
    'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
    'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
    'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
    'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
    'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
    'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
    'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
    'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
    'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
    'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
    'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
    'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
    'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
    'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
    'WhiteSmoke', 'Yellow', 'YellowGreen'
]
73
+
74
+
75
def _get_multiplier_for_color_randomness():
  """Returns a multiplier to get semi-random colors from successive indices.

  This function picks a prime number, p, from the candidates
  [5, 7, 11, 13, 17] that:
  - does not divide len(STANDARD_COLORS)
  - is closest to len(STANDARD_COLORS) / 10 (ties go to the smaller prime)

  If every candidate divides len(STANDARD_COLORS), p is returned as 1.

  Once p is established, it can be used as a multiplier to select
  non-consecutive colors from STANDARD_COLORS:
  colors = [(p * i) % len(STANDARD_COLORS) for i in range(20)]

  Returns:
    The selected prime multiplier, or 1 if no candidate qualifies.
  """
  num_colors = len(STANDARD_COLORS)

  # Keep only the primes that do not divide the number of colors; a divisor
  # would make (p * i) % num_colors cycle through a subset of the colors.
  prime_candidates = [p for p in (5, 7, 11, 13, 17) if num_colors % p]
  if not prime_candidates:
    return 1

  # min() returns the first minimum, so ties resolve to the earlier (smaller)
  # candidate.
  target = num_colors / 10.
  return min(prime_candidates, key=lambda p: abs(target - p))
101
+
102
+
103
def save_image_array_as_png(image, output_path):
  """Saves an image (represented as a numpy array) to PNG.

  The array is cast to uint8 and converted to RGB before being written.

  Args:
    image: a numpy array with shape [height, width, 3].
    output_path: path to which image should be written.
  """
  image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
  with tf.gfile.Open(output_path, 'w') as fid:
    image_pil.save(fid, 'PNG')
113
+
114
+
115
def encode_image_array_as_png_str(image):
  """Encodes a numpy array into a PNG string.

  Args:
    image: a numpy array with shape [height, width, 3].

  Returns:
    PNG encoded image string.
  """
  image_pil = Image.fromarray(np.uint8(image))
  # The context manager closes the in-memory buffer even if encoding fails.
  with six.BytesIO() as output:
    image_pil.save(output, format='PNG')
    return output.getvalue()
130
+
131
+
132
def draw_bounding_box_on_image_array(image,
                                     ymin,
                                     xmin,
                                     ymax,
                                     xmax,
                                     color='red',
                                     thickness=4,
                                     display_str_list=(),
                                     use_normalized_coordinates=True):
  """Adds a bounding box to an image (numpy array).

  Bounding box coordinates can be specified in either absolute (pixel) or
  normalized coordinates by setting the use_normalized_coordinates argument.

  The drawing is done on a PIL copy of the array and then copied back into
  `image`, so the input array is modified in place and nothing is returned.

  Args:
    image: a numpy array with shape [height, width, 3].
    ymin: ymin of bounding box.
    xmin: xmin of bounding box.
    ymax: ymax of bounding box.
    xmax: xmax of bounding box.
    color: color to draw bounding box. Default is red.
    thickness: line thickness. Default value is 4.
    display_str_list: list of strings to display in box
                      (each to be shown on its own line).
    use_normalized_coordinates: If True (default), treat coordinates
      ymin, xmin, ymax, xmax as relative to the image.  Otherwise treat
      coordinates as absolute.
  """
  image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
  draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
                             thickness, display_str_list,
                             use_normalized_coordinates)
  np.copyto(image, np.array(image_pil))
165
+
166
+
167
def draw_bounding_box_on_image(image,
                               ymin,
                               xmin,
                               ymax,
                               xmax,
                               color='red',
                               thickness=4,
                               display_str_list=(),
                               use_normalized_coordinates=True):
  """Adds a bounding box to an image.

  Bounding box coordinates can be specified in either absolute (pixel) or
  normalized coordinates by setting the use_normalized_coordinates argument.

  Each string in display_str_list is displayed on a separate line above the
  bounding box in black text on a rectangle filled with the input 'color'.
  If the top of the bounding box extends to the edge of the image, the strings
  are displayed below the bounding box.

  Args:
    image: a PIL.Image object. It is drawn on in place.
    ymin: ymin of bounding box.
    xmin: xmin of bounding box.
    ymax: ymax of bounding box.
    xmax: xmax of bounding box.
    color: color to draw bounding box. Default is red.
    thickness: line thickness. Default value is 4. No outline is drawn when
      thickness is not positive.
    display_str_list: list of strings to display in box
                      (each to be shown on its own line).
    use_normalized_coordinates: If True (default), treat coordinates
      ymin, xmin, ymax, xmax as relative to the image.  Otherwise treat
      coordinates as absolute.
  """
  draw = ImageDraw.Draw(image)
  im_width, im_height = image.size
  if use_normalized_coordinates:
    (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                  ymin * im_height, ymax * im_height)
  else:
    (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
  if thickness > 0:
    draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
               (left, top)],
              width=thickness,
              fill=color)
  try:
    font = ImageFont.truetype('arial.ttf', 24)
  except IOError:
    # Fall back to PIL's built-in bitmap font when arial.ttf is unavailable.
    font = ImageFont.load_default()

  # If the total height of the display strings added to the top of the bounding
  # box exceeds the top of the image, stack the strings below the bounding box
  # instead of above.
  text_sizes = [_get_text_size(font, ds) for ds in display_str_list]
  # Each display_str has a top and bottom margin of 0.05x.
  total_display_str_height = (1 + 2 * 0.05) * sum(h for _, h in text_sizes)

  if top > total_display_str_height:
    text_bottom = top
  else:
    text_bottom = bottom + total_display_str_height
  # Reverse list and print from bottom to top.
  for display_str, (text_width, text_height) in zip(display_str_list[::-1],
                                                    text_sizes[::-1]):
    margin = np.ceil(0.05 * text_height)
    draw.rectangle(
        [(left, text_bottom - text_height - 2 * margin), (left + text_width,
                                                          text_bottom)],
        fill=color)
    draw.text(
        (left + margin, text_bottom - text_height - margin),
        display_str,
        fill='black',
        font=font)
    # Move up by the full height of the rectangle just drawn (text plus both
    # margins) so consecutive labels stack without overlapping.
    text_bottom -= text_height + 2 * margin


def _get_text_size(font, text):
  """Returns (width, height) of `text` rendered with `font`.

  Uses font.getsize when the font object provides it and otherwise falls back
  to the right/bottom coordinates of font.getbbox, so the function works with
  Pillow versions both before and after getsize was removed.
  """
  if hasattr(font, 'getsize'):
    return font.getsize(text)
  bbox = font.getbbox(text)
  return bbox[2], bbox[3]
242
+
243
+
244
def draw_bounding_boxes_on_image_array(image,
                                       boxes,
                                       color='red',
                                       thickness=4,
                                       display_str_list_list=()):
  """Draws bounding boxes on image (numpy array).

  The boxes are drawn on a PIL copy of the array and then copied back into
  `image`, so the input array is modified in place and nothing is returned.

  Args:
    image: a numpy array object.
    boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
           The coordinates are in normalized format between [0, 1].
    color: color to draw bounding box. Default is red.
    thickness: line thickness. Default value is 4.
    display_str_list_list: list of list of strings.
                           a list of strings for each bounding box.
                           The reason to pass a list of strings for a
                           bounding box is that it might contain
                           multiple labels.

  Raises:
    ValueError: if boxes is not a [N, 4] array
  """
  image_pil = Image.fromarray(image)
  draw_bounding_boxes_on_image(image_pil, boxes, color, thickness,
                               display_str_list_list)
  np.copyto(image, np.array(image_pil))
270
+
271
+
272
def draw_bounding_boxes_on_image(image,
                                 boxes,
                                 color='red',
                                 thickness=4,
                                 display_str_list_list=()):
  """Draws bounding boxes on image.

  Args:
    image: a PIL.Image object.
    boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
           The coordinates are in normalized format between [0, 1].
    color: color to draw bounding box. Default is red.
    thickness: line thickness. Default value is 4.
    display_str_list_list: list of list of strings.
                           a list of strings for each bounding box.
                           The reason to pass a list of strings for a
                           bounding box is that it might contain
                           multiple labels. When non-empty it is indexed by
                           box position, so it needs one entry per box.

  Raises:
    ValueError: if boxes is not a [N, 4] array
  """
  boxes_shape = boxes.shape
  # A zero-dimensional array has an empty shape tuple: nothing to draw.
  if not boxes_shape:
    return
  if len(boxes_shape) != 2 or boxes_shape[1] != 4:
    raise ValueError('Input must be of size [N, 4]')
  for i in range(boxes_shape[0]):
    display_str_list = ()
    if display_str_list_list:
      display_str_list = display_str_list_list[i]
    draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2],
                               boxes[i, 3], color, thickness, display_str_list)
305
+
306
+
307
def create_visualization_fn(category_index,
                            include_masks=False,
                            include_keypoints=False,
                            include_keypoint_scores=False,
                            include_track_ids=False,
                            **kwargs):
  """Constructs a visualization function that can be wrapped in a py_func.

  py_funcs only accept positional arguments. This function returns a suitable
  function with the correct positional argument mapping. The positional
  arguments in order are:
  0: image
  1: boxes
  2: classes
  3: scores
  [4]: masks (optional)
  [4-5]: keypoints (optional)
  [4-6]: keypoint_scores (optional)
  [4-7]: track_ids (optional)

  The optional arguments are packed densely: each enabled optional input takes
  the next free position after the ones enabled before it.

  -- Example 1 --
  vis_only_masks_fn = create_visualization_fn(category_index,
    include_masks=True, include_keypoints=False, include_track_ids=False,
    **kwargs)
  image = tf.py_func(vis_only_masks_fn,
                     inp=[image, boxes, classes, scores, masks],
                     Tout=tf.uint8)

  -- Example 2 --
  vis_masks_and_track_ids_fn = create_visualization_fn(category_index,
    include_masks=True, include_keypoints=False, include_track_ids=True,
    **kwargs)
  image = tf.py_func(vis_masks_and_track_ids_fn,
                     inp=[image, boxes, classes, scores, masks, track_ids],
                     Tout=tf.uint8)

  Args:
    category_index: a dict that maps integer ids to category dicts. e.g.
      {1: {1: 'dog'}, 2: {2: 'cat'}, ...}
    include_masks: Whether masks should be expected as a positional argument in
      the returned function.
    include_keypoints: Whether keypoints should be expected as a positional
      argument in the returned function.
    include_keypoint_scores: Whether keypoint scores should be expected as a
      positional argument in the returned function.
    include_track_ids: Whether track ids should be expected as a positional
      argument in the returned function.
    **kwargs: Additional kwargs that will be passed to
      visualize_boxes_and_labels_on_image_array.

  Returns:
    Returns a function that only takes tensors as positional arguments.
  """

  def visualization_py_func_fn(*args):
    """Visualization function that can be wrapped in a tf.py_func.

    Args:
      *args: First 4 positional arguments must be:
        image - uint8 numpy array with shape (img_height, img_width, 3).
        boxes - a numpy array of shape [N, 4].
        classes - a numpy array of shape [N].
        scores - a numpy array of shape [N] or None.
        -- Optional positional arguments --
        instance_masks - a numpy array of shape [N, image_height, image_width].
        keypoints - a numpy array of shape [N, num_keypoints, 2].
        keypoint_scores - a numpy array of shape [N, num_keypoints].
        track_ids - a numpy array of shape [N] with unique track ids.

    Returns:
      uint8 numpy array with shape (img_height, img_width, 3) with overlaid
      boxes.
    """
    image = args[0]
    boxes = args[1]
    classes = args[2]
    scores = args[3]
    masks = keypoints = keypoint_scores = track_ids = None
    pos_arg_ptr = 4  # Positional argument for first optional tensor (masks).
    if include_masks:
      masks = args[pos_arg_ptr]
      pos_arg_ptr += 1
    if include_keypoints:
      keypoints = args[pos_arg_ptr]
      pos_arg_ptr += 1
    if include_keypoint_scores:
      keypoint_scores = args[pos_arg_ptr]
      pos_arg_ptr += 1
    if include_track_ids:
      track_ids = args[pos_arg_ptr]

    return visualize_boxes_and_labels_on_image_array(
        image,
        boxes,
        classes,
        scores,
        category_index=category_index,
        instance_masks=masks,
        keypoints=keypoints,
        keypoint_scores=keypoint_scores,
        track_ids=track_ids,
        **kwargs)
  return visualization_py_func_fn
410
+
411
+
412
def draw_heatmaps_on_image(image, heatmaps):
  """Draws heatmaps on an image.

  The heatmaps are handled channel by channel and different colors are used to
  paint different heatmap channels. Colors are taken from STANDARD_COLORS by
  channel index and wrap around when there are more channels than colors.

  Args:
    image: a PIL.Image object. It is drawn on in place.
    heatmaps: a numpy array with shape [image_height, image_width, channel].
      Note that the image_height and image_width should match the size of input
      image.
  """
  draw = ImageDraw.Draw(image)
  num_channels = heatmaps.shape[2]
  num_colors = len(STANDARD_COLORS)
  for c in range(num_channels):
    # Scale the channel by 255 and store it as an 8-bit single-band image.
    heatmap = (heatmaps[:, :, c] * 255).astype('uint8')
    # The 8-bit ('L') bitmap is passed as the mask as-is; no 1-bit conversion
    # is applied.
    bitmap = Image.fromarray(heatmap, 'L')
    draw.bitmap(
        xy=[(0, 0)],
        bitmap=bitmap,
        fill=STANDARD_COLORS[c % num_colors])
435
+
436
+
437
def draw_heatmaps_on_image_array(image, heatmaps):
  """Overlays heatmaps to an image (numpy array).

  The function overlays the heatmaps on top of image. The heatmap values will be
  painted with different colors depending on the channels. Similar to
  "draw_heatmaps_on_image" function except the inputs are numpy arrays and the
  result is returned as a new array instead of being drawn in place.

  Args:
    image: a numpy array with shape [height, width, 3]. Objects that are not
      numpy arrays are converted by calling their .numpy() method.
    heatmaps: a numpy array with shape [height, width, channel]. Objects that
      are not numpy arrays are converted by calling their .numpy() method.

  Returns:
    An uint8 numpy array representing the input image painted with heatmap
    colors.
  """
  if not isinstance(image, np.ndarray):
    image = image.numpy()
  if not isinstance(heatmaps, np.ndarray):
    heatmaps = heatmaps.numpy()
  image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
  draw_heatmaps_on_image(image_pil, heatmaps)
  return np.array(image_pil)
459
+
460
+
461
def draw_heatmaps_on_image_tensors(images,
                                   heatmaps,
                                   apply_sigmoid=False):
  """Draws heatmaps on batch of image tensors.

  Args:
    images: A 4D uint8 image tensor of shape [N, H, W, C]. If C > 3, additional
      channels will be ignored. If C = 1, then we convert the images to RGB
      images.
    heatmaps: [N, h, w, channel] float32 tensor of heatmaps. Note that the
      heatmaps will be resized to match the input image size before overlaying
      the heatmaps with input images. Theoretically the heatmap height width
      should have the same aspect ratio as the input image to avoid potential
      misalignment introduced by the image resize.
    apply_sigmoid: Whether to apply a sigmoid layer on top of the heatmaps. If
      the heatmaps come directly from the prediction logits, then we should
      apply the sigmoid layer to make sure the values are in between [0.0, 1.0].

  Returns:
    4D image tensor of type uint8, with heatmaps overlaid on top.
  """
  # Additional channels are being ignored.
  if images.shape[3] > 3:
    images = images[:, :, :, 0:3]
  elif images.shape[3] == 1:
    images = tf.image.grayscale_to_rgb(images)

  _, height, width, _ = shape_utils.combined_static_and_dynamic_shape(images)
  if apply_sigmoid:
    heatmaps = tf.math.sigmoid(heatmaps)
  resized_heatmaps = tf.image.resize(heatmaps, size=[height, width])

  elems = [images, resized_heatmaps]

  def draw_heatmaps(image_and_heatmaps):
    """Draws heatmaps on a single (image, heatmaps) pair via a py_function."""
    image_with_heatmaps = tf.py_function(
        draw_heatmaps_on_image_array,
        image_and_heatmaps,
        tf.uint8)
    return image_with_heatmaps

  # Apply the drawing function to each batch element.
  images = tf.map_fn(draw_heatmaps, elems, dtype=tf.uint8, back_prop=False)
  return images
504
+
505
+
506
def _resize_original_image(image, image_shape):
  """Resizes a single image with nearest-neighbor interpolation.

  A leading batch dimension is added for the resize op and removed afterwards.

  Args:
    image: a single image tensor (presumably [height, width, channels];
      confirm against callers).
    image_shape: the target size passed to tf.image.resize_images.

  Returns:
    The resized image cast to tf.uint8.
  """
  image = tf.expand_dims(image, 0)
  image = tf.image.resize_images(
      image,
      image_shape,
      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
      align_corners=True)
  return tf.cast(tf.squeeze(image, 0), tf.uint8)
514
+
515
+
516
def draw_bounding_boxes_on_image_tensors(images,
                                         boxes,
                                         classes,
                                         scores,
                                         category_index,
                                         original_image_spatial_shape=None,
                                         true_image_shape=None,
                                         instance_masks=None,
                                         keypoints=None,
                                         keypoint_scores=None,
                                         keypoint_edges=None,
                                         track_ids=None,
                                         max_boxes_to_draw=20,
                                         min_score_thresh=0.2,
                                         use_normalized_coordinates=True):
  """Draws bounding boxes, masks, and keypoints on batch of image tensors.

  Args:
    images: A 4D uint8 image tensor of shape [N, H, W, C]. If C > 3, additional
      channels will be ignored. If C = 1, then we convert the images to RGB
      images.
    boxes: [N, max_detections, 4] float32 tensor of detection boxes.
    classes: [N, max_detections] int tensor of detection classes. Note that
      classes are 1-indexed.
    scores: [N, max_detections] float32 tensor of detection scores.
    category_index: a dict that maps integer ids to category dicts. e.g.
      {1: {1: 'dog'}, 2: {2: 'cat'}, ...}
    original_image_spatial_shape: [N, 2] tensor containing the spatial size of
      the original image.
    true_image_shape: [N, 3] tensor containing the spatial size of unpadded
      original_image.
    instance_masks: A 4D uint8 tensor of shape [N, max_detection, H, W] with
      instance masks.
    keypoints: A 4D float32 tensor of shape [N, max_detection, num_keypoints, 2]
      with keypoints.
    keypoint_scores: A 3D float32 tensor of shape [N, max_detection,
      num_keypoints] with keypoint scores.
    keypoint_edges: A list of tuples with keypoint indices that specify which
      keypoints should be connected by an edge, e.g. [(0, 1), (2, 4)] draws
      edges from keypoint 0 to 1 and from keypoint 2 to 4.
    track_ids: [N, max_detections] int32 tensor of unique tracks ids (i.e.
      instance ids for each object). If provided, the color-coding of boxes is
      dictated by these ids, and not classes.
    max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20.
    min_score_thresh: Minimum score threshold for visualization. Default 0.2.
    use_normalized_coordinates: Whether to assume boxes and keypoints are in
      normalized coordinates (as opposed to absolute coordinates).
      Default is True.

  Returns:
    4D image tensor of type uint8, with boxes drawn on top.
  """
  # Additional channels are being ignored.
  if images.shape[3] > 3:
    images = images[:, :, :, 0:3]
  elif images.shape[3] == 1:
    images = tf.image.grayscale_to_rgb(images)
  visualization_keyword_args = {
      'use_normalized_coordinates': use_normalized_coordinates,
      'max_boxes_to_draw': max_boxes_to_draw,
      'min_score_thresh': min_score_thresh,
      'agnostic_mode': False,
      'line_thickness': 4,
      'keypoint_edges': keypoint_edges
  }
  # tf.map_fn needs one entry per batch element for every input, so missing
  # shape tensors are replaced by placeholders filled with -1.
  if true_image_shape is None:
    true_shapes = tf.constant(-1, shape=[images.shape.as_list()[0], 3])
  else:
    true_shapes = true_image_shape
  if original_image_spatial_shape is None:
    original_shapes = tf.constant(-1, shape=[images.shape.as_list()[0], 2])
  else:
    original_shapes = original_image_spatial_shape

  visualize_boxes_fn = create_visualization_fn(
      category_index,
      include_masks=instance_masks is not None,
      include_keypoints=keypoints is not None,
      include_keypoint_scores=keypoint_scores is not None,
      include_track_ids=track_ids is not None,
      **visualization_keyword_args)

  # The optional tensors are appended in the positional order expected by the
  # function returned from create_visualization_fn.
  elems = [true_shapes, original_shapes, images, boxes, classes, scores]
  if instance_masks is not None:
    elems.append(instance_masks)
  if keypoints is not None:
    elems.append(keypoints)
  if keypoint_scores is not None:
    elems.append(keypoint_scores)
  if track_ids is not None:
    elems.append(track_ids)

  def draw_boxes(image_and_detections):
    """Draws boxes on image."""
    true_shape = image_and_detections[0]
    original_shape = image_and_detections[1]
    # Start from the input image so that `image` is always defined, even when
    # only original_image_spatial_shape is provided.
    image = image_and_detections[2]
    if true_image_shape is not None:
      image = shape_utils.pad_or_clip_nd(image,
                                         [true_shape[0], true_shape[1], 3])
    if original_image_spatial_shape is not None:
      image_and_detections[2] = _resize_original_image(image, original_shape)

    # The first two elements (the shape tensors) are not inputs of the
    # visualization function.
    image_with_boxes = tf.py_func(visualize_boxes_fn, image_and_detections[2:],
                                  tf.uint8)
    return image_with_boxes

  images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False)
  return images
624
+
625
+
626
+ def draw_side_by_side_evaluation_image(eval_dict,
627
+ category_index,
628
+ max_boxes_to_draw=20,
629
+ min_score_thresh=0.2,
630
+ use_normalized_coordinates=True,
631
+ keypoint_edges=None):
632
+ """Creates a side-by-side image with detections and groundtruth.
633
+
634
+ Bounding boxes (and instance masks, if available) are visualized on both
635
+ subimages.
636
+
637
+ Args:
638
+ eval_dict: The evaluation dictionary returned by
639
+ eval_util.result_dict_for_batched_example() or
640
+ eval_util.result_dict_for_single_example().
641
+ category_index: A category index (dictionary) produced from a labelmap.
642
+ max_boxes_to_draw: The maximum number of boxes to draw for detections.
643
+ min_score_thresh: The minimum score threshold for showing detections.
644
+ use_normalized_coordinates: Whether to assume boxes and keypoints are in
645
+ normalized coordinates (as opposed to absolute coordinates).
646
+ Default is True.
647
+ keypoint_edges: A list of tuples with keypoint indices that specify which
648
+ keypoints should be connected by an edge, e.g. [(0, 1), (2, 4)] draws
649
+ edges from keypoint 0 to 1 and from keypoint 2 to 4.
650
+
651
+ Returns:
652
+ A list of [1, H, 2 * W, C] uint8 tensor. The subimage on the left
653
+ corresponds to detections, while the subimage on the right corresponds to
654
+ groundtruth.
655
+ """
656
+ detection_fields = fields.DetectionResultFields()
657
+ input_data_fields = fields.InputDataFields()
658
+
659
+ images_with_detections_list = []
660
+
661
+ # Add the batch dimension if the eval_dict is for single example.
662
+ if len(eval_dict[detection_fields.detection_classes].shape) == 1:
663
+ for key in eval_dict:
664
+ if (key != input_data_fields.original_image and
665
+ key != input_data_fields.image_additional_channels):
666
+ eval_dict[key] = tf.expand_dims(eval_dict[key], 0)
667
+
668
+ for indx in range(eval_dict[input_data_fields.original_image].shape[0]):
669
+ instance_masks = None
670
+ if detection_fields.detection_masks in eval_dict:
671
+ instance_masks = tf.cast(
672
+ tf.expand_dims(
673
+ eval_dict[detection_fields.detection_masks][indx], axis=0),
674
+ tf.uint8)
675
+ keypoints = None
676
+ keypoint_scores = None
677
+ if detection_fields.detection_keypoints in eval_dict:
678
+ keypoints = tf.expand_dims(
679
+ eval_dict[detection_fields.detection_keypoints][indx], axis=0)
680
+ if detection_fields.detection_keypoint_scores in eval_dict:
681
+ keypoint_scores = tf.expand_dims(
682
+ eval_dict[detection_fields.detection_keypoint_scores][indx], axis=0)
683
+ else:
684
+ keypoint_scores = tf.cast(keypoint_ops.set_keypoint_visibilities(
685
+ keypoints), dtype=tf.float32)
686
+
687
+ groundtruth_instance_masks = None
688
+ if input_data_fields.groundtruth_instance_masks in eval_dict:
689
+ groundtruth_instance_masks = tf.cast(
690
+ tf.expand_dims(
691
+ eval_dict[input_data_fields.groundtruth_instance_masks][indx],
692
+ axis=0), tf.uint8)
693
+ groundtruth_keypoints = None
694
+ groundtruth_keypoint_scores = None
695
+ gt_kpt_vis_fld = input_data_fields.groundtruth_keypoint_visibilities
696
+ if input_data_fields.groundtruth_keypoints in eval_dict:
697
+ groundtruth_keypoints = tf.expand_dims(
698
+ eval_dict[input_data_fields.groundtruth_keypoints][indx], axis=0)
699
+ if gt_kpt_vis_fld in eval_dict:
700
+ groundtruth_keypoint_scores = tf.expand_dims(
701
+ tf.cast(eval_dict[gt_kpt_vis_fld][indx], dtype=tf.float32), axis=0)
702
+ else:
703
+ groundtruth_keypoint_scores = tf.cast(
704
+ keypoint_ops.set_keypoint_visibilities(
705
+ groundtruth_keypoints), dtype=tf.float32)
706
+
707
+ images_with_detections = draw_bounding_boxes_on_image_tensors(
708
+ tf.expand_dims(
709
+ eval_dict[input_data_fields.original_image][indx], axis=0),
710
+ tf.expand_dims(
711
+ eval_dict[detection_fields.detection_boxes][indx], axis=0),
712
+ tf.expand_dims(
713
+ eval_dict[detection_fields.detection_classes][indx], axis=0),
714
+ tf.expand_dims(
715
+ eval_dict[detection_fields.detection_scores][indx], axis=0),
716
+ category_index,
717
+ original_image_spatial_shape=tf.expand_dims(
718
+ eval_dict[input_data_fields.original_image_spatial_shape][indx],
719
+ axis=0),
720
+ true_image_shape=tf.expand_dims(
721
+ eval_dict[input_data_fields.true_image_shape][indx], axis=0),
722
+ instance_masks=instance_masks,
723
+ keypoints=keypoints,
724
+ keypoint_scores=keypoint_scores,
725
+ keypoint_edges=keypoint_edges,
726
+ max_boxes_to_draw=max_boxes_to_draw,
727
+ min_score_thresh=min_score_thresh,
728
+ use_normalized_coordinates=use_normalized_coordinates)
729
+ images_with_groundtruth = draw_bounding_boxes_on_image_tensors(
730
+ tf.expand_dims(
731
+ eval_dict[input_data_fields.original_image][indx], axis=0),
732
+ tf.expand_dims(
733
+ eval_dict[input_data_fields.groundtruth_boxes][indx], axis=0),
734
+ tf.expand_dims(
735
+ eval_dict[input_data_fields.groundtruth_classes][indx], axis=0),
736
+ tf.expand_dims(
737
+ tf.ones_like(
738
+ eval_dict[input_data_fields.groundtruth_classes][indx],
739
+ dtype=tf.float32),
740
+ axis=0),
741
+ category_index,
742
+ original_image_spatial_shape=tf.expand_dims(
743
+ eval_dict[input_data_fields.original_image_spatial_shape][indx],
744
+ axis=0),
745
+ true_image_shape=tf.expand_dims(
746
+ eval_dict[input_data_fields.true_image_shape][indx], axis=0),
747
+ instance_masks=groundtruth_instance_masks,
748
+ keypoints=groundtruth_keypoints,
749
+ keypoint_scores=groundtruth_keypoint_scores,
750
+ keypoint_edges=keypoint_edges,
751
+ max_boxes_to_draw=None,
752
+ min_score_thresh=0.0,
753
+ use_normalized_coordinates=use_normalized_coordinates)
754
+ images_to_visualize = tf.concat([images_with_detections,
755
+ images_with_groundtruth], axis=2)
756
+
757
+ if input_data_fields.image_additional_channels in eval_dict:
758
+ images_with_additional_channels_groundtruth = (
759
+ draw_bounding_boxes_on_image_tensors(
760
+ tf.expand_dims(
761
+ eval_dict[input_data_fields.image_additional_channels][indx],
762
+ axis=0),
763
+ tf.expand_dims(
764
+ eval_dict[input_data_fields.groundtruth_boxes][indx], axis=0),
765
+ tf.expand_dims(
766
+ eval_dict[input_data_fields.groundtruth_classes][indx],
767
+ axis=0),
768
+ tf.expand_dims(
769
+ tf.ones_like(
770
+ eval_dict[input_data_fields.groundtruth_classes][indx],
771
+ dtype=tf.float32),
772
+ axis=0),
773
+ category_index,
774
+ original_image_spatial_shape=tf.expand_dims(
775
+ eval_dict[input_data_fields.original_image_spatial_shape]
776
+ [indx],
777
+ axis=0),
778
+ true_image_shape=tf.expand_dims(
779
+ eval_dict[input_data_fields.true_image_shape][indx], axis=0),
780
+ instance_masks=groundtruth_instance_masks,
781
+ keypoints=None,
782
+ keypoint_edges=None,
783
+ max_boxes_to_draw=None,
784
+ min_score_thresh=0.0,
785
+ use_normalized_coordinates=use_normalized_coordinates))
786
+ images_to_visualize = tf.concat(
787
+ [images_to_visualize, images_with_additional_channels_groundtruth],
788
+ axis=2)
789
+ images_with_detections_list.append(images_to_visualize)
790
+
791
+ return images_with_detections_list
792
+
793
+
794
def draw_keypoints_on_image_array(image,
                                  keypoints,
                                  keypoint_scores=None,
                                  min_score_thresh=0.5,
                                  color='red',
                                  radius=2,
                                  use_normalized_coordinates=True,
                                  keypoint_edges=None,
                                  keypoint_edge_color='green',
                                  keypoint_edge_width=2):
  """Renders keypoints (and optional edges) onto a numpy image, in place.

  The array is converted to an RGB PIL image, drawn on by
  `draw_keypoints_on_image`, and the rendered pixels are copied back into
  `image`.

  Args:
    image: a numpy array with shape [height, width, 3]. Modified in place.
    keypoints: a numpy array with shape [num_keypoints, 2].
    keypoint_scores: a numpy array with shape [num_keypoints]. If provided,
      only keypoints whose score is above `min_score_thresh` are visualized.
    min_score_thresh: A scalar indicating the minimum keypoint score required
      for a keypoint to be visualized. Only takes effect when
      `keypoint_scores` is provided.
    color: color to draw the keypoints with. Default is red.
    radius: keypoint radius. Default value is 2.
    use_normalized_coordinates: if True (default), treat keypoint values as
      relative to the image. Otherwise treat them as absolute.
    keypoint_edges: A list of tuples with keypoint indices that specify which
      keypoints should be connected by an edge, e.g. [(0, 1), (2, 4)] draws
      edges from keypoint 0 to 1 and from keypoint 2 to 4.
    keypoint_edge_color: color to draw the keypoint edges with. Default is
      green.
    keypoint_edge_width: width of the edges drawn between keypoints. Default
      value is 2.
  """
  drawing_options = dict(
      keypoint_scores=keypoint_scores,
      min_score_thresh=min_score_thresh,
      color=color,
      radius=radius,
      use_normalized_coordinates=use_normalized_coordinates,
      keypoint_edges=keypoint_edges,
      keypoint_edge_color=keypoint_edge_color,
      keypoint_edge_width=keypoint_edge_width)
  canvas = Image.fromarray(np.uint8(image)).convert('RGB')
  draw_keypoints_on_image(canvas, keypoints, **drawing_options)
  # Write the rendered pixels back into the caller's array.
  np.copyto(image, np.array(canvas))
837
+
838
+
839
def draw_keypoints_on_image(image,
                            keypoints,
                            keypoint_scores=None,
                            min_score_thresh=0.5,
                            color='red',
                            radius=2,
                            use_normalized_coordinates=True,
                            keypoint_edges=None,
                            keypoint_edge_color='green',
                            keypoint_edge_width=2):
  """Draws keypoints, and optionally the edges joining them, on a PIL image.

  Args:
    image: a PIL.Image object. Drawn on in place.
    keypoints: a numpy array with shape [num_keypoints, 2], each row being
      (y, x). An empty input draws nothing.
    keypoint_scores: a numpy array with shape [num_keypoints].
    min_score_thresh: a score threshold for visualizing keypoints. Only used if
      keypoint_scores is provided.
    color: color to draw the keypoints with. Default is red.
    radius: keypoint radius. Default value is 2.
    use_normalized_coordinates: if True (default), treat keypoint values as
      relative to the image. Otherwise treat them as absolute.
    keypoint_edges: A list of tuples with keypoint indices that specify which
      keypoints should be connected by an edge, e.g. [(0, 1), (2, 4)] draws
      edges from keypoint 0 to 1 and from keypoint 2 to 4.
    keypoint_edge_color: color to draw the keypoint edges with. Default is
      green.
    keypoint_edge_width: width of the edges drawn between keypoints. Default
      value is 2.
  """
  keypoints = np.array(keypoints)
  if keypoints.size == 0:
    # Nothing to draw. Without this guard the NaN check below fails on an
    # empty input, which has no second axis to reduce over.
    return

  draw = ImageDraw.Draw(image)
  im_width, im_height = image.size
  # Each keypoint is (y, x); split into per-axis coordinate lists.
  keypoints_x = [k[1] for k in keypoints]
  keypoints_y = [k[0] for k in keypoints]
  if use_normalized_coordinates:
    keypoints_x = [im_width * x for x in keypoints_x]
    keypoints_y = [im_height * y for y in keypoints_y]

  if keypoint_scores is not None:
    valid_kpt = np.greater(np.array(keypoint_scores), min_score_thresh)
  else:
    # Without scores, a keypoint is drawable unless a coordinate is NaN.
    valid_kpt = np.logical_not(np.any(np.isnan(keypoints), axis=1))
  valid_kpt = list(valid_kpt)

  for keypoint_x, keypoint_y, valid in zip(keypoints_x, keypoints_y, valid_kpt):
    if valid:
      draw.ellipse([(keypoint_x - radius, keypoint_y - radius),
                    (keypoint_x + radius, keypoint_y + radius)],
                   outline=color, fill=color)

  if keypoint_edges is not None:
    num_keypoints = len(keypoints)
    for keypoint_start, keypoint_end in keypoint_edges:
      # Silently skip edges that reference keypoints that do not exist.
      if (keypoint_start < 0 or keypoint_start >= num_keypoints or
          keypoint_end < 0 or keypoint_end >= num_keypoints):
        continue
      # An edge is only drawn when both of its endpoints are drawable.
      if not (valid_kpt[keypoint_start] and valid_kpt[keypoint_end]):
        continue
      edge_coordinates = [
          keypoints_x[keypoint_start], keypoints_y[keypoint_start],
          keypoints_x[keypoint_end], keypoints_y[keypoint_end]
      ]
      draw.line(
          edge_coordinates, fill=keypoint_edge_color, width=keypoint_edge_width)
903
+
904
+
905
def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
  """Blends a solid-color binary mask onto an image, in place.

  Args:
    image: uint8 numpy array with shape (img_height, img_width, 3). Modified
      in place.
    mask: a uint8 numpy array of shape (img_height, img_width) whose values
      are either 0 or 1.
    color: color to draw the mask with. Default is red.
    alpha: transparency value between 0 and 1. (default: 0.4)

  Raises:
    ValueError: On incorrect data type for image or mask, on mask values other
      than 0 and 1, or when the mask's shape differs from the image's spatial
      dimensions.
  """
  if image.dtype != np.uint8:
    raise ValueError('`image` not of type np.uint8')
  if mask.dtype != np.uint8:
    raise ValueError('`mask` not of type np.uint8')
  is_binary = np.logical_or(mask == 1, mask == 0)
  if not np.all(is_binary):
    raise ValueError('`mask` elements should be in [0, 1]')
  image_hw = image.shape[:2]
  if image_hw != mask.shape:
    raise ValueError('The image has spatial dimensions %s but the mask has '
                     'dimensions %s' % (image_hw, mask.shape))

  # Build an image-sized plane filled with the requested color.
  rgb_triplet = np.array(list(ImageColor.getrgb(color))).reshape([1, 1, 3])
  color_plane = np.ones_like(mask)[:, :, np.newaxis] * rgb_triplet
  overlay = Image.fromarray(np.uint8(color_plane)).convert('RGBA')
  # The mask scaled by alpha is the per-pixel blend weight for the overlay.
  opacity = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L')
  blended = Image.composite(overlay, Image.fromarray(image), opacity)
  np.copyto(image, np.array(blended.convert('RGB')))
936
+
937
+
938
def visualize_boxes_and_labels_on_image_array(
    image,
    boxes,
    classes,
    scores,
    category_index,
    instance_masks=None,
    instance_boundaries=None,
    keypoints=None,
    keypoint_scores=None,
    keypoint_edges=None,
    track_ids=None,
    use_normalized_coordinates=False,
    max_boxes_to_draw=20,
    min_score_thresh=.5,
    agnostic_mode=False,
    line_thickness=4,
    groundtruth_box_visualization_color='black',
    skip_boxes=False,
    skip_scores=False,
    skip_labels=False,
    skip_track_ids=False):
  """Draws labeled boxes (plus masks and keypoints) onto an image array.

  Detections are first grouped by box location: every box that passes the
  score threshold contributes a display string, and boxes sharing identical
  coordinates share one entry. Each distinct box is then rendered with its
  mask, boundary, rectangle, labels and keypoints. The image is modified in
  place and the same array is returned.

  Args:
    image: uint8 numpy array with shape (img_height, img_width, 3)
    boxes: a numpy array of shape [N, 4]
    classes: a numpy array of shape [N]. Note that class indices are 1-based,
      and match the keys in the label map.
    scores: a numpy array of shape [N] or None. If scores=None, the boxes are
      treated as groundtruth: all of them are drawn in
      `groundtruth_box_visualization_color` with no class or score text.
    category_index: a dict containing category dictionaries (each holding
      category index `id` and category name `name`) keyed by category indices.
    instance_masks: a numpy array of shape [N, image_height, image_width] with
      values ranging between 0 and 1, can be None.
    instance_boundaries: a numpy array of shape [N, image_height, image_width]
      with values ranging between 0 and 1, can be None.
    keypoints: a numpy array of shape [N, num_keypoints, 2], can be None.
    keypoint_scores: a numpy array of shape [N, num_keypoints], can be None.
    keypoint_edges: A list of tuples with keypoint indices that specify which
      keypoints should be connected by an edge, e.g. [(0, 1), (2, 4)] draws
      edges from keypoint 0 to 1 and from keypoint 2 to 4.
    track_ids: a numpy array of shape [N] with unique track ids. If provided,
      color-coding of boxes will be determined by these ids, and not the class
      indices.
    use_normalized_coordinates: whether boxes is to be interpreted as
      normalized coordinates or not.
    max_boxes_to_draw: maximum number of boxes to visualize. If None, draw
      all boxes.
    min_score_thresh: minimum score threshold for a box or keypoint to be
      visualized.
    agnostic_mode: boolean (default: False) controlling whether to evaluate in
      class-agnostic mode or not. This mode will display scores but ignore
      classes.
    line_thickness: integer (default: 4) controlling line width of the boxes.
    groundtruth_box_visualization_color: box color for visualizing groundtruth
      boxes
    skip_boxes: whether to skip the drawing of bounding boxes.
    skip_scores: whether to skip score when drawing a single detection
    skip_labels: whether to skip label when drawing a single detection
    skip_track_ids: whether to skip track id when drawing a single detection

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes.
  """
  # Per-box bookkeeping, keyed by the box coordinates as a tuple.
  labels_by_box = collections.defaultdict(list)
  color_by_box = collections.defaultdict(str)
  masks_by_box = {}
  boundaries_by_box = {}
  keypoints_by_box = collections.defaultdict(list)
  keypoint_scores_by_box = collections.defaultdict(list)
  track_id_by_box = {}

  num_boxes = boxes.shape[0]
  if not max_boxes_to_draw:
    max_boxes_to_draw = num_boxes

  for idx in range(num_boxes):
    if len(color_by_box) == max_boxes_to_draw:
      break
    # Detections at or below the threshold are dropped; groundtruth (no
    # scores) is always kept.
    if scores is not None and not scores[idx] > min_score_thresh:
      continue

    box = tuple(boxes[idx].tolist())
    if instance_masks is not None:
      masks_by_box[box] = instance_masks[idx]
    if instance_boundaries is not None:
      boundaries_by_box[box] = instance_boundaries[idx]
    if keypoints is not None:
      keypoints_by_box[box].extend(keypoints[idx])
    if keypoint_scores is not None:
      keypoint_scores_by_box[box].extend(keypoint_scores[idx])
    if track_ids is not None:
      track_id_by_box[box] = track_ids[idx]

    if scores is None:
      color_by_box[box] = groundtruth_box_visualization_color
      continue

    # Assemble "<class>: <score>%: ID <track id>", omitting skipped parts.
    label_parts = []
    if not skip_labels and not agnostic_mode:
      if classes[idx] in six.viewkeys(category_index):
        class_name = category_index[classes[idx]]['name']
      else:
        class_name = 'N/A'
      class_label = str(class_name)
      if class_label:
        label_parts.append(class_label)
    if not skip_scores:
      label_parts.append('{}%'.format(round(100*scores[idx])))
    if not skip_track_ids and track_ids is not None:
      label_parts.append('ID {}'.format(track_ids[idx]))
    labels_by_box[box].append(': '.join(label_parts))

    if agnostic_mode:
      color_by_box[box] = 'DarkOrange'
    elif track_ids is not None:
      prime_multipler = _get_multiplier_for_color_randomness()
      color_by_box[box] = STANDARD_COLORS[
          (prime_multipler * track_ids[idx]) % len(STANDARD_COLORS)]
    else:
      color_by_box[box] = STANDARD_COLORS[
          classes[idx] % len(STANDARD_COLORS)]

  # Render every distinct box location.
  for box, color in color_by_box.items():
    ymin, xmin, ymax, xmax = box
    if instance_masks is not None:
      draw_mask_on_image_array(image, masks_by_box[box], color=color)
    if instance_boundaries is not None:
      draw_mask_on_image_array(
          image, boundaries_by_box[box], color='red', alpha=1.0)
    draw_bounding_box_on_image_array(
        image,
        ymin,
        xmin,
        ymax,
        xmax,
        color=color,
        thickness=0 if skip_boxes else line_thickness,
        display_str_list=labels_by_box[box],
        use_normalized_coordinates=use_normalized_coordinates)
    if keypoints is not None:
      # Scores are only forwarded when some were collected above.
      scores_for_box = (
          keypoint_scores_by_box[box] if keypoint_scores_by_box else None)
      draw_keypoints_on_image_array(
          image,
          keypoints_by_box[box],
          scores_for_box,
          min_score_thresh=min_score_thresh,
          color=color,
          radius=line_thickness / 2,
          use_normalized_coordinates=use_normalized_coordinates,
          keypoint_edges=keypoint_edges,
          keypoint_edge_color=color,
          keypoint_edge_width=line_thickness // 2)

  return image
1112
+
1113
+
1114
def add_cdf_image_summary(values, name):
  """Adds a tf.summary.image for a CDF plot of the values.

  Normalizes `values` such that they sum to 1, plots the cumulative distribution
  function and creates a tf image summary.

  Args:
    values: a 1-D float32 tensor containing the values.
    name: name for the image summary.
  """
  def cdf_plot(values):
    """Numpy function that renders the CDF to a [1, H, W, 3] uint8 image."""
    normalized_values = values / np.sum(values)
    sorted_values = np.sort(normalized_values)
    cumulative_values = np.cumsum(sorted_values)
    fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32)
                            / cumulative_values.size)
    fig = plt.figure(frameon=False)
    try:
      # Integer position spec; the string form '111' is not accepted by newer
      # Matplotlib releases.
      ax = fig.add_subplot(111)
      ax.plot(fraction_of_examples, cumulative_values)
      ax.set_ylabel('cumulative normalized values')
      ax.set_xlabel('fraction of examples')
      fig.canvas.draw()
      width, height = fig.get_size_inches() * fig.get_dpi()
      # NOTE(review): `tostring_rgb` is reportedly deprecated in recent
      # Matplotlib releases - confirm against the pinned version.
      # frombuffer replaces the deprecated binary mode of np.fromstring; the
      # copy keeps the returned array writable, as before.
      image = np.frombuffer(
          fig.canvas.tostring_rgb(), dtype='uint8').reshape(
              1, int(height), int(width), 3).copy()
    finally:
      # Release the figure; pyplot otherwise keeps it alive after each call.
      plt.close(fig)
    return image
  cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8)
  tf.summary.image(name, cdf_plot)
1143
+
1144
+
1145
def add_hist_image_summary(values, bins, name):
  """Adds a tf.summary.image for a histogram plot of the values.

  Plots the histogram of values and creates a tf image summary.

  Args:
    values: a 1-D float32 tensor containing the values.
    bins: bin edges which will be directly passed to np.histogram.
    name: name for the image summary.
  """

  def hist_plot(values, bins):
    """Numpy function that renders the histogram to a [1, H, W, 3] image."""
    fig = plt.figure(frameon=False)
    try:
      # Integer position spec; the string form '111' is not accepted by newer
      # Matplotlib releases.
      ax = fig.add_subplot(111)
      y, x = np.histogram(values, bins=bins)
      # np.histogram returns one more edge than counts; plot against the left
      # edge of each bin.
      ax.plot(x[:-1], y)
      ax.set_ylabel('count')
      ax.set_xlabel('value')
      fig.canvas.draw()
      width, height = fig.get_size_inches() * fig.get_dpi()
      # NOTE(review): `tostring_rgb` is reportedly deprecated in recent
      # Matplotlib releases - confirm against the pinned version.
      # frombuffer replaces the deprecated binary mode of np.fromstring; the
      # copy keeps the returned array writable, as before.
      image = np.frombuffer(
          fig.canvas.tostring_rgb(), dtype='uint8').reshape(
              1, int(height), int(width), 3).copy()
    finally:
      # Release the figure; pyplot otherwise keeps it alive after each call.
      plt.close(fig)
    return image
  hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8)
  tf.summary.image(name, hist_plot)
1172
+
1173
+
1174
class EvalMetricOpsVisualization(six.with_metaclass(abc.ABCMeta, object)):
  """Abstract base class responsible for visualizations during evaluation.

  Currently, summary images are not run during evaluation. One way to produce
  evaluation images in Tensorboard is to provide tf.summary.image strings as
  `value_ops` in tf.estimator.EstimatorSpec's `eval_metric_ops`. This class
  accumulates images (with overlaid detections and groundtruth) and returns a
  dictionary that can be passed to `eval_metric_ops`. Subclasses supply the
  actual rendering through `images_from_evaluation_dict`.
  """

  def __init__(self,
               category_index,
               max_examples_to_draw=5,
               max_boxes_to_draw=20,
               min_score_thresh=0.2,
               use_normalized_coordinates=True,
               summary_name_prefix='evaluation_image',
               keypoint_edges=None):
    """Creates an EvalMetricOpsVisualization.

    Args:
      category_index: A category index (dictionary) produced from a labelmap.
      max_examples_to_draw: The maximum number of example summaries to produce.
      max_boxes_to_draw: The maximum number of boxes to draw for detections.
      min_score_thresh: The minimum score threshold for showing detections.
      use_normalized_coordinates: Whether to assume boxes and keypoints are in
        normalized coordinates (as opposed to absolute coordinates).
        Default is True.
      summary_name_prefix: A string prefix for each image summary.
      keypoint_edges: A list of tuples with keypoint indices that specify which
        keypoints should be connected by an edge, e.g. [(0, 1), (2, 4)] draws
        edges from keypoint 0 to 1 and from keypoint 2 to 4.
    """
    # Images accumulated between successive reads of the summaries.
    self._images = []
    self._summary_name_prefix = summary_name_prefix
    self._max_examples_to_draw = max_examples_to_draw
    # Drawing configuration, consumed by subclasses.
    self._category_index = category_index
    self._max_boxes_to_draw = max_boxes_to_draw
    self._min_score_thresh = min_score_thresh
    self._use_normalized_coordinates = use_normalized_coordinates
    self._keypoint_edges = keypoint_edges

  def clear(self):
    """Discards every image stored so far."""
    self._images = []

  def add_images(self, images):
    """Store a list of images, each with shape [1, H, W, C]."""
    limit = self._max_examples_to_draw
    if len(self._images) >= limit:
      return

    self._images.extend(images)
    # Trim any overflow so that at most `limit` images are retained.
    del self._images[limit:]

  def get_estimator_eval_metric_ops(self, eval_dict):
    """Returns metric ops for use in tf.estimator.EstimatorSpec.

    Args:
      eval_dict: A dictionary that holds an image, groundtruth, and detections
        for a batched example. Note that, we use only the first example for
        visualization. See eval_util.result_dict_for_batched_example() for a
        convenient method for constructing such a dictionary. The dictionary
        contains
        fields.InputDataFields.original_image: [batch_size, H, W, 3] image.
        fields.InputDataFields.original_image_spatial_shape: [batch_size, 2]
          tensor containing the size of the original image.
        fields.InputDataFields.true_image_shape: [batch_size, 3]
          tensor containing the spatial size of the unpadded original image.
        fields.InputDataFields.groundtruth_boxes - [batch_size, num_boxes, 4]
          float32 tensor with groundtruth boxes in range [0.0, 1.0].
        fields.InputDataFields.groundtruth_classes - [batch_size, num_boxes]
          int64 tensor with 1-indexed groundtruth classes.
        fields.InputDataFields.groundtruth_instance_masks - (optional)
          [batch_size, num_boxes, H, W] int64 tensor with instance masks.
        fields.InputDataFields.groundtruth_keypoints - (optional)
          [batch_size, num_boxes, num_keypoints, 2] float32 tensor with
          keypoint coordinates in format [y, x].
        fields.InputDataFields.groundtruth_keypoint_visibilities - (optional)
          [batch_size, num_boxes, num_keypoints] bool tensor with
          keypoint visibilities.
        fields.DetectionResultFields.detection_boxes - [batch_size,
          max_num_boxes, 4] float32 tensor with detection boxes in range [0.0,
          1.0].
        fields.DetectionResultFields.detection_classes - [batch_size,
          max_num_boxes] int64 tensor with 1-indexed detection classes.
        fields.DetectionResultFields.detection_scores - [batch_size,
          max_num_boxes] float32 tensor with detection scores.
        fields.DetectionResultFields.detection_masks - (optional) [batch_size,
          max_num_boxes, H, W] float32 tensor of binarized masks.
        fields.DetectionResultFields.detection_keypoints - (optional)
          [batch_size, max_num_boxes, num_keypoints, 2] float32 tensor with
          keypoints.
        fields.DetectionResultFields.detection_keypoint_scores - (optional)
          [batch_size, max_num_boxes, num_keypoints] float32 tensor with
          keypoints scores.

    Returns:
      A dictionary of image summary names to tuple of (value_op, update_op). The
      `update_op` is the same for all items in the dictionary, and is
      responsible for saving a single side-by-side image with detections and
      groundtruth. Each `value_op` holds the tf.summary.image string for a given
      image. An empty dictionary is returned when no examples are to be drawn.
    """
    if self._max_examples_to_draw == 0:
      return {}
    images = self.images_from_evaluation_dict(eval_dict)

    def get_images():
      """Returns stored images, padded to self._max_examples_to_draw."""
      stored = self._images
      shortfall = self._max_examples_to_draw - len(stored)
      # Scalar placeholders fill unused slots; they are told apart from real
      # images by their rank when the summaries are built.
      stored.extend(np.array(0, dtype=np.uint8) for _ in range(shortfall))
      self.clear()
      return stored

    def image_summary_or_default_string(summary_name, image):
      """Returns an image summary for real images, '' for padding."""
      is_rank_four = tf.equal(tf.size(tf.shape(image)), 4)
      return tf.cond(
          is_rank_four,
          lambda: tf.summary.image(summary_name, image),
          lambda: tf.constant(''))

    if tf.executing_eagerly():
      update_op = self.add_images([[images[0]]])
      image_tensors = get_images()
    else:
      update_op = tf.py_func(self.add_images, [[images[0]]], [])
      image_tensors = tf.py_func(
          get_images, [], [tf.uint8] * self._max_examples_to_draw)

    eval_metric_ops = {}
    for slot, image_tensor in enumerate(image_tensors):
      summary_name = '/'.join([self._summary_name_prefix, str(slot)])
      eval_metric_ops[summary_name] = (
          image_summary_or_default_string(summary_name, image_tensor),
          update_op)
    return eval_metric_ops

  @abc.abstractmethod
  def images_from_evaluation_dict(self, eval_dict):
    """Converts evaluation dictionary into a list of image tensors.

    To be overridden by implementations.

    Args:
      eval_dict: A dictionary with all the necessary information for producing
        visualizations.

    Returns:
      A list of [1, H, W, C] uint8 tensors.
    """
    raise NotImplementedError
1326
+
1327
+
1328
class VisualizeSingleFrameDetections(EvalMetricOpsVisualization):
  """Visualizes single-frame detections next to their groundtruth."""

  def __init__(self,
               category_index,
               max_examples_to_draw=5,
               max_boxes_to_draw=20,
               min_score_thresh=0.2,
               use_normalized_coordinates=True,
               summary_name_prefix='Detections_Left_Groundtruth_Right',
               keypoint_edges=None):
    """Forwards all settings to EvalMetricOpsVisualization.

    The only difference from the base class is the default
    `summary_name_prefix`.
    """
    base_settings = dict(
        category_index=category_index,
        max_examples_to_draw=max_examples_to_draw,
        max_boxes_to_draw=max_boxes_to_draw,
        min_score_thresh=min_score_thresh,
        use_normalized_coordinates=use_normalized_coordinates,
        summary_name_prefix=summary_name_prefix,
        keypoint_edges=keypoint_edges)
    super(VisualizeSingleFrameDetections, self).__init__(**base_settings)

  def images_from_evaluation_dict(self, eval_dict):
    """Renders `eval_dict` via draw_side_by_side_evaluation_image."""
    drawing_args = (
        eval_dict,
        self._category_index,
        self._max_boxes_to_draw,
        self._min_score_thresh,
        self._use_normalized_coordinates,
        self._keypoint_edges,
    )
    return draw_side_by_side_evaluation_image(*drawing_args)