RobotJelly committed
Commit ada4657
1 Parent(s): 4b7383b
Files changed (1)
app.py +309 -0
app.py ADDED
import gradio as gr
from huggingface_hub import from_pretrained_keras
from PIL import Image
import io
import matplotlib.pyplot as plt
import os
import numpy as np
import tensorflow as tf

# Collect up to ten sample images from the local COCO test split to use as
# Gradio examples.
coco_image = []
coco_dir = "coco/images/test2017/"
for idx, fname in enumerate(os.listdir(coco_dir)):
    image_path = os.path.join(coco_dir, fname)
    if os.path.isfile(image_path) and idx < 10:
        coco_image.append(image_path)

class AnchorBox:
    """Generates anchor boxes.

    This class has operations to generate anchor boxes for feature maps at
    strides `[8, 16, 32, 64, 128]`, where each anchor box is of the format
    `[x, y, width, height]`.

    Attributes:
      aspect_ratios: A list of float values representing the aspect ratios of
        the anchor boxes at each location on the feature map.
      scales: A list of float values representing the scale of the anchor boxes
        at each location on the feature map.
      num_anchors: The number of anchor boxes at each location on the feature map.
      areas: A list of float values representing the areas of the anchor
        boxes for each feature map in the feature pyramid.
      strides: A list of float values representing the strides for each feature
        map in the feature pyramid.
    """

    def __init__(self):
        self.aspect_ratios = [0.5, 1.0, 2.0]
        self.scales = [2 ** x for x in [0, 1 / 3, 2 / 3]]

        self._num_anchors = len(self.aspect_ratios) * len(self.scales)
        self._strides = [2 ** i for i in range(3, 8)]
        self._areas = [x ** 2 for x in [32.0, 64.0, 128.0, 256.0, 512.0]]
        self._anchor_dims = self._compute_dims()

    def _compute_dims(self):
        """Computes anchor box dimensions for all ratios and scales at all levels
        of the feature pyramid.
        """
        anchor_dims_all = []
        for area in self._areas:
            anchor_dims = []
            for ratio in self.aspect_ratios:
                anchor_height = tf.math.sqrt(area / ratio)
                anchor_width = area / anchor_height
                dims = tf.reshape(
                    tf.stack([anchor_width, anchor_height], axis=-1), [1, 1, 2]
                )
                for scale in self.scales:
                    anchor_dims.append(scale * dims)
            anchor_dims_all.append(tf.stack(anchor_dims, axis=-2))
        return anchor_dims_all

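    # Example of the dims computed above (hedged arithmetic, not the author's
    # comment): at area 32**2 = 1024 with ratio 0.5, anchor_height =
    # sqrt(1024 / 0.5) ~= 45.25 and anchor_width = 1024 / 45.25 ~= 22.63
    # (so width / height = 0.5), and each pair is then scaled by
    # 2**0, 2**(1/3), and 2**(2/3).
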
    def _get_anchors(self, feature_height, feature_width, level):
        """Generates anchor boxes for a given feature map size and level.

        Arguments:
          feature_height: An integer representing the height of the feature map.
          feature_width: An integer representing the width of the feature map.
          level: An integer representing the level of the feature map in the
            feature pyramid.

        Returns:
          anchor boxes with the shape
          `(feature_height * feature_width * num_anchors, 4)`
        """
        rx = tf.range(feature_width, dtype=tf.float32) + 0.5
        ry = tf.range(feature_height, dtype=tf.float32) + 0.5
        centers = tf.stack(tf.meshgrid(rx, ry), axis=-1) * self._strides[level - 3]
        centers = tf.expand_dims(centers, axis=-2)
        centers = tf.tile(centers, [1, 1, self._num_anchors, 1])
        dims = tf.tile(
            self._anchor_dims[level - 3], [feature_height, feature_width, 1, 1]
        )
        anchors = tf.concat([centers, dims], axis=-1)
        return tf.reshape(
            anchors, [feature_height * feature_width * self._num_anchors, 4]
        )

    def get_anchors(self, image_height, image_width):
        """Generates anchor boxes for all the feature maps of the feature pyramid.

        Arguments:
          image_height: Height of the input image.
          image_width: Width of the input image.

        Returns:
          anchor boxes for all the feature maps, stacked as a single tensor
          with shape `(total_anchors, 4)`
        """
        anchors = [
            self._get_anchors(
                tf.math.ceil(image_height / 2 ** i),
                tf.math.ceil(image_width / 2 ** i),
                i,
            )
            for i in range(3, 8)
        ]
        return tf.concat(anchors, axis=0)

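# A quick sanity check of the anchor generator (a hedged sketch, not part of
# the app's control flow): with 3 aspect ratios x 3 scales = 9 anchors per
# location and strides [8, 16, 32, 64, 128], a 640x640 input yields
# sum(ceil(640 / s) ** 2 * 9 for s in (8, 16, 32, 64, 128)) = 76725 anchors.
# _anchors = AnchorBox().get_anchors(640.0, 640.0)  # shape: (76725, 4)
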
class DecodePredictions(tf.keras.layers.Layer):
    """A Keras layer that decodes predictions of the RetinaNet model.

    Attributes:
      num_classes: Number of classes in the dataset.
      confidence_threshold: Minimum class probability, below which detections
        are pruned.
      nms_iou_threshold: IOU threshold for the NMS operation.
      max_detections_per_class: Maximum number of detections to retain per
        class.
      max_detections: Maximum number of detections to retain across all
        classes.
      box_variance: The scaling factors used to scale the bounding box
        predictions.
    """

    def __init__(
        self,
        num_classes=80,
        confidence_threshold=0.05,
        nms_iou_threshold=0.5,
        max_detections_per_class=100,
        max_detections=100,
        box_variance=[0.1, 0.1, 0.2, 0.2],
        **kwargs
    ):
        super(DecodePredictions, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.confidence_threshold = confidence_threshold
        self.nms_iou_threshold = nms_iou_threshold
        self.max_detections_per_class = max_detections_per_class
        self.max_detections = max_detections

        self._anchor_box = AnchorBox()
        # Use the `box_variance` argument instead of a hard-coded copy of its
        # default, so non-default variances actually take effect.
        self._box_variance = tf.convert_to_tensor(box_variance, dtype=tf.float32)

    def _decode_box_predictions(self, anchor_boxes, box_predictions):
        boxes = box_predictions * self._box_variance
        boxes = tf.concat(
            [
                boxes[:, :, :2] * anchor_boxes[:, :, 2:] + anchor_boxes[:, :, :2],
                tf.math.exp(boxes[:, :, 2:]) * anchor_boxes[:, :, 2:],
            ],
            axis=-1,
        )
        return convert_to_corners(boxes)

    def call(self, images, predictions):
        image_shape = tf.cast(tf.shape(images), dtype=tf.float32)
        anchor_boxes = self._anchor_box.get_anchors(image_shape[1], image_shape[2])
        box_predictions = predictions[:, :, :4]
        cls_predictions = tf.nn.sigmoid(predictions[:, :, 4:])
        boxes = self._decode_box_predictions(anchor_boxes[None, ...], box_predictions)

        return tf.image.combined_non_max_suppression(
            tf.expand_dims(boxes, axis=2),
            cls_predictions,
            self.max_detections_per_class,
            self.max_detections,
            self.nms_iou_threshold,
            self.confidence_threshold,
            clip_boxes=False,
        )

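# The decoding above inverts the standard RetinaNet box encoding (a hedged
# sketch of the math, assuming that encoding): given an anchor
# (xa, ya, wa, ha), a regression output (dx, dy, dw, dh), and variances
# (vx, vy, vw, vh),
#   x = dx * vx * wa + xa        y = dy * vy * ha + ya
#   w = wa * exp(dw * vw)        h = ha * exp(dh * vh)
# so an all-zero prediction decodes exactly back to its anchor, e.g. anchor
# (100, 100, 50, 50) -> corners (75, 75, 125, 125) after convert_to_corners.
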
def convert_to_corners(boxes):
    """Changes the box format to corner coordinates.

    Arguments:
      boxes: A tensor of rank 2 or higher with a shape of `(..., num_boxes, 4)`
        representing bounding boxes where each box is of the format
        `[x, y, width, height]`.

    Returns:
      converted boxes with shape same as that of boxes.
    """
    return tf.concat(
        [boxes[..., :2] - boxes[..., 2:] / 2.0, boxes[..., :2] + boxes[..., 2:] / 2.0],
        axis=-1,
    )

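# Quick numeric check: a center-format box [x=50, y=50, w=20, h=10] becomes
# corners [40., 45., 60., 55.]:
# convert_to_corners(tf.constant([[50.0, 50.0, 20.0, 10.0]]))
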
def resize_and_pad_image(
    image, min_side=800.0, max_side=1333.0, jitter=[640, 1024], stride=128.0
):
    """Resizes and pads an image while preserving its aspect ratio.

    1. Resizes the image so that the shorter side is equal to `min_side`.
    2. If the longer side is then greater than `max_side`, resizes the image
       so that the longer side is equal to `max_side`.
    3. Pads with zeros on the right and bottom to make the image shape
       divisible by `stride`.

    Arguments:
      image: A 3-D tensor of shape `(height, width, channels)` representing an
        image.
      min_side: The shorter side of the image is resized to this value, if
        `jitter` is set to None.
      max_side: If the longer side of the image exceeds this value after
        resizing, the image is resized so that the longer side equals this
        value.
      jitter: A list of floats containing the minimum and maximum size for
        scale jittering. If set, the shorter side of the image is resized to
        a random value in this range.
      stride: The stride of the smallest feature map in the feature pyramid.
        Can be calculated as `image_size / feature_map_size`.

    Returns:
      image: Resized and padded image.
      image_shape: Shape of the image before padding.
      ratio: The scaling factor used to resize the image.
    """
    image_shape = tf.cast(tf.shape(image)[:2], dtype=tf.float32)
    if jitter is not None:
        min_side = tf.random.uniform((), jitter[0], jitter[1], dtype=tf.float32)
    ratio = min_side / tf.reduce_min(image_shape)
    if ratio * tf.reduce_max(image_shape) > max_side:
        ratio = max_side / tf.reduce_max(image_shape)
    image_shape = ratio * image_shape
    image = tf.image.resize(image, tf.cast(image_shape, dtype=tf.int32))
    padded_image_shape = tf.cast(
        tf.math.ceil(image_shape / stride) * stride, dtype=tf.int32
    )
    image = tf.image.pad_to_bounding_box(
        image, 0, 0, padded_image_shape[0], padded_image_shape[1]
    )
    return image, image_shape, ratio

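# Worked example (a sketch with assumed dimensions): a 400x600 image with
# jitter=None gets ratio = 800 / 400 = 2.0 and is resized to 800x1200
# (1200 <= max_side, so no second rescale), then zero-padded up to the next
# multiples of 128: ceil(800 / 128) * 128 = 896 by
# ceil(1200 / 128) * 128 = 1280.
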
def visualize_detections(
    image, boxes, classes, scores, figsize=(7, 7), linewidth=1, color=[0, 0, 1]
):
    """Visualize Detections"""
    image = np.array(image, dtype=np.uint8)
    plt.figure(figsize=figsize)
    plt.axis("off")
    plt.imshow(image)
    ax = plt.gca()
    for box, _cls, score in zip(boxes, classes, scores):
        text = "{}: {:.2f}".format(_cls, score)
        x1, y1, x2, y2 = box
        w, h = x2 - x1, y2 - y1
        patch = plt.Rectangle(
            [x1, y1], w, h, fill=False, edgecolor=color, linewidth=linewidth
        )
        ax.add_patch(patch)
        ax.text(
            x1,
            y1,
            text,
            bbox={"facecolor": color, "alpha": 0.4},
            clip_box=ax.clipbox,
            clip_on=True,
        )
    plt.show()
    return ax

def prepare_image(image):
    # Resize without scale jittering at inference time, then apply the same
    # ResNet preprocessing the backbone was trained with.
    image, _, ratio = resize_and_pad_image(image, jitter=None)
    image = tf.keras.applications.resnet.preprocess_input(image)
    return tf.expand_dims(image, axis=0), ratio

# `int2str` maps a predicted COCO label id to its category name. It is
# referenced below but was never defined in this file; the Keras RetinaNet
# tutorial takes it from the tfds COCO dataset info, which is assumed here.
# The builder metadata is assumed to provide the label names without
# downloading the dataset itself.
import tensorflow_datasets as tfds

int2str = tfds.builder("coco/2017").info.features["objects"]["label"].int2str

# Load the trained RetinaNet from the Hub and wrap it with the decoding
# layer, so the inference model maps raw images straight to NMS-ed detections.
model = from_pretrained_keras("keras-io/Object-Detection-RetinaNet")
img_input = tf.keras.Input(shape=[None, None, 3], name="image")
predictions = model(img_input, training=False)
detections = DecodePredictions(confidence_threshold=0.5)(img_input, predictions)
inference_model = tf.keras.Model(inputs=img_input, outputs=detections)

def predict(image):
    input_image, ratio = prepare_image(image)
    # `detections` is the named tuple returned by
    # tf.image.combined_non_max_suppression:
    # (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections).
    detections = inference_model.predict(input_image)
    num_detections = detections.valid_detections[0]
    class_names = [
        int2str(int(x)) for x in detections.nmsed_classes[0][:num_detections]
    ]
    img_buf = io.BytesIO()
    ax = visualize_detections(
        image,
        detections.nmsed_boxes[0][:num_detections] / ratio,
        class_names,
        detections.nmsed_scores[0][:num_detections],
    )
    ax.figure.savefig(img_buf)
    # Close the figure so repeated requests do not accumulate open figures.
    plt.close(ax.figure)
    img_buf.seek(0)
    img = Image.open(img_buf)
    return img

# Input
input = gr.inputs.Image(image_mode="RGB", type="numpy", label="Enter Object Image")

# Output
output = gr.outputs.Image(type="pil", label="Detected Objects with Class Category")

title = "Object Detection With RetinaNet"
description = "Upload an image, or pick one of the examples, to localize the objects it contains and classify them into categories."

gr.Interface(
    fn=predict,
    inputs=input,
    outputs=output,
    examples=coco_image,
    allow_flagging=False,
    analytics_enabled=False,
    title=title,
    description=description,
).launch(enable_queue=True, debug=True)