Spaces: Runtime error
RobotJelly committed
Commit • ada4657
Parent(s): 4b7383b
app.py ADDED
@@ -0,0 +1,309 @@
import gradio as gr
from huggingface_hub import from_pretrained_keras
from PIL import Image
import io
import matplotlib.pyplot as plt
import os
import re
import zipfile
import numpy as np
import tensorflow as tf
from tensorflow import keras

coco_image = []
coco_dir = 'coco/images/test2017/'
for idx, images in enumerate(os.listdir(coco_dir)):
    image = os.path.join(coco_dir, images)
    if os.path.isfile(image) and idx < 10:
        coco_image.append(image)

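# AnchorBox below reproduces the anchor scheme of the keras.io RetinaNet
# example: 9 anchors per location (3 aspect ratios x 3 scales) over pyramid
# levels P3-P7, i.e. strides 8-128 with base areas 32^2-512^2.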
class AnchorBox:
    """Generates anchor boxes.

    This class has operations to generate anchor boxes for feature maps at
    strides `[8, 16, 32, 64, 128]`, where each anchor box is of the
    format `[x, y, width, height]`.

    Attributes:
      aspect_ratios: A list of float values representing the aspect ratios of
        the anchor boxes at each location on the feature map.
      scales: A list of float values representing the scale of the anchor boxes
        at each location on the feature map.
      num_anchors: The number of anchor boxes at each location on the feature map.
      areas: A list of float values representing the areas of the anchor
        boxes for each feature map in the feature pyramid.
      strides: A list of float values representing the strides for each feature
        map in the feature pyramid.
    """

    def __init__(self):
        self.aspect_ratios = [0.5, 1.0, 2.0]
        self.scales = [2 ** x for x in [0, 1 / 3, 2 / 3]]

        self._num_anchors = len(self.aspect_ratios) * len(self.scales)
        self._strides = [2 ** i for i in range(3, 8)]
        self._areas = [x ** 2 for x in [32.0, 64.0, 128.0, 256.0, 512.0]]
        self._anchor_dims = self._compute_dims()

    def _compute_dims(self):
        """Computes anchor box dimensions for all ratios and scales at all levels
        of the feature pyramid.
        """
        anchor_dims_all = []
        for area in self._areas:
            anchor_dims = []
            for ratio in self.aspect_ratios:
                anchor_height = tf.math.sqrt(area / ratio)
                anchor_width = area / anchor_height
                dims = tf.reshape(
                    tf.stack([anchor_width, anchor_height], axis=-1), [1, 1, 2]
                )
                for scale in self.scales:
                    anchor_dims.append(scale * dims)
            anchor_dims_all.append(tf.stack(anchor_dims, axis=-2))
        return anchor_dims_all

    def _get_anchors(self, feature_height, feature_width, level):
        """Generates anchor boxes for a given feature map size and level.

        Arguments:
          feature_height: An integer representing the height of the feature map.
          feature_width: An integer representing the width of the feature map.
          level: An integer representing the level of the feature map in the
            feature pyramid.

        Returns:
          anchor boxes with the shape
          `(feature_height * feature_width * num_anchors, 4)`
        """
        rx = tf.range(feature_width, dtype=tf.float32) + 0.5
        ry = tf.range(feature_height, dtype=tf.float32) + 0.5
        centers = tf.stack(tf.meshgrid(rx, ry), axis=-1) * self._strides[level - 3]
        centers = tf.expand_dims(centers, axis=-2)
        centers = tf.tile(centers, [1, 1, self._num_anchors, 1])
        dims = tf.tile(
            self._anchor_dims[level - 3], [feature_height, feature_width, 1, 1]
        )
        anchors = tf.concat([centers, dims], axis=-1)
        return tf.reshape(
            anchors, [feature_height * feature_width * self._num_anchors, 4]
        )

    def get_anchors(self, image_height, image_width):
        """Generates anchor boxes for all the feature maps of the feature pyramid.

        Arguments:
          image_height: Height of the input image.
          image_width: Width of the input image.

        Returns:
          anchor boxes for all the feature maps, stacked as a single tensor
          with shape `(total_anchors, 4)`
        """
        anchors = [
            self._get_anchors(
                tf.math.ceil(image_height / 2 ** i),
                tf.math.ceil(image_width / 2 ** i),
                i,
            )
            for i in range(3, 8)
        ]
        return tf.concat(anchors, axis=0)

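# DecodePredictions turns raw network outputs into final boxes: the first 4
# channels are box offsets (scaled by the box variances, with exp() applied to
# the width/height terms), and the remaining channels are per-class logits
# passed through a sigmoid. tf.image.combined_non_max_suppression then returns
# a namedtuple with nmsed_boxes, nmsed_scores, nmsed_classes and
# valid_detections, which predict() below indexes into.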
class DecodePredictions(tf.keras.layers.Layer):
    """A Keras layer that decodes predictions of the RetinaNet model.

    Attributes:
      num_classes: Number of classes in the dataset
      confidence_threshold: Minimum class probability, below which detections
        are pruned.
      nms_iou_threshold: IOU threshold for the NMS operation
      max_detections_per_class: Maximum number of detections to retain per
        class.
      max_detections: Maximum number of detections to retain across all
        classes.
      box_variance: The scaling factors used to scale the bounding box
        predictions.
    """

    def __init__(
        self,
        num_classes=80,
        confidence_threshold=0.05,
        nms_iou_threshold=0.5,
        max_detections_per_class=100,
        max_detections=100,
        box_variance=[0.1, 0.1, 0.2, 0.2],
        **kwargs
    ):
        super(DecodePredictions, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.confidence_threshold = confidence_threshold
        self.nms_iou_threshold = nms_iou_threshold
        self.max_detections_per_class = max_detections_per_class
        self.max_detections = max_detections

        self._anchor_box = AnchorBox()
        self._box_variance = tf.convert_to_tensor(
            [0.1, 0.1, 0.2, 0.2], dtype=tf.float32
        )

    def _decode_box_predictions(self, anchor_boxes, box_predictions):
        boxes = box_predictions * self._box_variance
        boxes = tf.concat(
            [
                boxes[:, :, :2] * anchor_boxes[:, :, 2:] + anchor_boxes[:, :, :2],
                tf.math.exp(boxes[:, :, 2:]) * anchor_boxes[:, :, 2:],
            ],
            axis=-1,
        )
        boxes_transformed = convert_to_corners(boxes)
        return boxes_transformed

    def call(self, images, predictions):
        image_shape = tf.cast(tf.shape(images), dtype=tf.float32)
        anchor_boxes = self._anchor_box.get_anchors(image_shape[1], image_shape[2])
        box_predictions = predictions[:, :, :4]
        cls_predictions = tf.nn.sigmoid(predictions[:, :, 4:])
        boxes = self._decode_box_predictions(anchor_boxes[None, ...], box_predictions)

        return tf.image.combined_non_max_suppression(
            tf.expand_dims(boxes, axis=2),
            cls_predictions,
            self.max_detections_per_class,
            self.max_detections,
            self.nms_iou_threshold,
            self.confidence_threshold,
            clip_boxes=False,
        )

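# convert_to_corners maps center-size boxes [x, y, width, height] to corner
# boxes [x1, y1, x2, y2]; e.g. [50, 50, 20, 10] -> [40, 45, 60, 55].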
def convert_to_corners(boxes):
    """Changes the box format to corner coordinates.

    Arguments:
      boxes: A tensor of rank 2 or higher with a shape of `(..., num_boxes, 4)`
        representing bounding boxes where each box is of the format
        `[x, y, width, height]`.

    Returns:
      converted boxes with the same shape as `boxes`.
    """
    return tf.concat(
        [boxes[..., :2] - boxes[..., 2:] / 2.0, boxes[..., :2] + boxes[..., 2:] / 2.0],
        axis=-1,
    )

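# Worked example for resize_and_pad_image with jitter=None: a 600x800 image
# gives ratio = 800 / 600 ~ 1.33 (longer side 800 * 1.33 ~ 1067 < 1333, so no
# second rescale), is resized to roughly 800x1066, and is then zero-padded to
# 896x1152, the next multiples of stride=128.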
def resize_and_pad_image(
    image, min_side=800.0, max_side=1333.0, jitter=[640, 1024], stride=128.0
):
    """Resizes and pads an image while preserving its aspect ratio.

    1. Resizes the image so that the shorter side is equal to `min_side`.
    2. If the longer side is then greater than `max_side`, resizes the image
       so that the longer side is equal to `max_side`.
    3. Pads with zeros on the right and bottom to make the image shape
       divisible by `stride`.

    Arguments:
      image: A 3-D tensor of shape `(height, width, channels)` representing an
        image.
      min_side: The shorter side of the image is resized to this value, if
        `jitter` is set to None.
      max_side: If the longer side of the image exceeds this value after
        resizing, the image is resized such that the longer side equals
        this value.
      jitter: A list of floats containing the minimum and maximum size for
        scale jittering. If available, the shorter side of the image is
        resized to a random value in this range.
      stride: The stride of the smallest feature map in the feature pyramid.
        Can be calculated as `image_size / feature_map_size`.

    Returns:
      image: Resized and padded image.
      image_shape: Shape of the image before padding.
      ratio: The scaling factor used to resize the image.
    """
    image_shape = tf.cast(tf.shape(image)[:2], dtype=tf.float32)
    if jitter is not None:
        min_side = tf.random.uniform((), jitter[0], jitter[1], dtype=tf.float32)
    ratio = min_side / tf.reduce_min(image_shape)
    if ratio * tf.reduce_max(image_shape) > max_side:
        ratio = max_side / tf.reduce_max(image_shape)
    image_shape = ratio * image_shape
    image = tf.image.resize(image, tf.cast(image_shape, dtype=tf.int32))
    padded_image_shape = tf.cast(
        tf.math.ceil(image_shape / stride) * stride, dtype=tf.int32
    )
    image = tf.image.pad_to_bounding_box(
        image, 0, 0, padded_image_shape[0], padded_image_shape[1]
    )
    return image, image_shape, ratio

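# visualize_detections expects corner-format boxes in original-image pixel
# coordinates (predict() divides the padded-image boxes by `ratio` before
# calling it). It draws each box and a "class: score" label with matplotlib
# and returns the Axes so the caller can save the figure.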
def visualize_detections(
    image, boxes, classes, scores, figsize=(7, 7), linewidth=1, color=[0, 0, 1]
):
    """Visualize Detections"""
    image = np.array(image, dtype=np.uint8)
    plt.figure(figsize=figsize)
    plt.axis("off")
    plt.imshow(image)
    ax = plt.gca()
    for box, _cls, score in zip(boxes, classes, scores):
        text = "{}: {:.2f}".format(_cls, score)
        x1, y1, x2, y2 = box
        w, h = x2 - x1, y2 - y1
        patch = plt.Rectangle(
            [x1, y1], w, h, fill=False, edgecolor=color, linewidth=linewidth
        )
        ax.add_patch(patch)
        ax.text(
            x1,
            y1,
            text,
            bbox={"facecolor": color, "alpha": 0.4},
            clip_box=ax.clipbox,
            clip_on=True,
        )
    plt.show()
    return ax

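# prepare_image resizes and pads the input (jitter disabled) and applies the
# ResNet preprocessing, returning the scaling ratio so detections can be
# mapped back to the original image.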
def prepare_image(image):
    image, _, ratio = resize_and_pad_image(image, jitter=None)
    image = tf.keras.applications.resnet.preprocess_input(image)
    return tf.expand_dims(image, axis=0), ratio

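# Build the end-to-end inference model: the pretrained RetinaNet from the Hub
# predicts raw box/class outputs, and DecodePredictions (confidence threshold
# 0.5) converts them into final detections.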
model = from_pretrained_keras("keras-io/Object-Detection-RetinaNet")
img_input = tf.keras.Input(shape=[None, None, 3], name="image")
predictions = model(img_input, training=False)
detections = DecodePredictions(confidence_threshold=0.5)(img_input, predictions)
inference_model = tf.keras.Model(inputs=img_input, outputs=detections)

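# `predict` below uses `int2str` to map predicted class ids to label names,
# but this file never defines it, so inference raises a NameError (a plausible
# cause of this Space's runtime error). A minimal sketch of the mapping,
# assuming the COCO/2017 label set from TensorFlow Datasets (as in the
# keras.io RetinaNet example this app follows):
try:
    import tensorflow_datasets as tfds  # assumption: available in the Space

    int2str = tfds.builder("coco/2017").info.features["objects"]["label"].int2str
except Exception:
    # Fallback sketch: just show the numeric class id.
    int2str = str
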
def predict(image):
    input_image, ratio = prepare_image(image)
    detections = inference_model.predict(input_image)
    num_detections = detections.valid_detections[0]
    class_names = [
        int2str(int(x)) for x in detections.nmsed_classes[0][:num_detections]
    ]
    img_buf = io.BytesIO()
    ax = visualize_detections(
        image,
        detections.nmsed_boxes[0][:num_detections] / ratio,
        class_names,
        detections.nmsed_scores[0][:num_detections],
    )
    ax.figure.savefig(img_buf)
    img_buf.seek(0)
    img = Image.open(img_buf)
    return img

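# Note: gr.inputs.Image / gr.outputs.Image, allow_flagging=False and
# launch(enable_queue=True) are the legacy Gradio API; newer Gradio releases
# replace these with gr.Image(...), allow_flagging="never" and
# Interface.queue(), which is another possible source of runtime errors here.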
# Input
input = gr.inputs.Image(image_mode="RGB", type="numpy", label="Enter Object Image")

# Output
output = gr.outputs.Image(type="pil", label="Detected Objects with Class Category")

title = "Object Detection With RetinaNet"
description = "Upload an image, or pick one of the examples, to localize the objects present in it and classify them into different categories."

gr.Interface(
    fn=predict,
    inputs=input,
    outputs=output,
    examples=coco_image,
    allow_flagging=False,
    analytics_enabled=False,
    title=title,
    description=description,
).launch(enable_queue=True, debug=True)