Jaime García Villena committed on
Commit
7cd0692
1 Parent(s): 9684c94

add a way to test this tflite

Browse files
Files changed (3) hide show
  1. coco_labels.json +82 -0
  2. test_images/cat.jpg +0 -0
  3. test_yolob8_tflite.py +205 -0
coco_labels.json ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "person",
3
+ "1": "bicycle",
4
+ "2": "car",
5
+ "3": "motorcycle",
6
+ "4": "airplane",
7
+ "5": "bus",
8
+ "6": "train",
9
+ "7": "truck",
10
+ "8": "boat",
11
+ "9": "traffic light",
12
+ "10": "fire hydrant",
13
+ "11": "stop sign",
14
+ "12": "parking meter",
15
+ "13": "bench",
16
+ "14": "bird",
17
+ "15": "cat",
18
+ "16": "dog",
19
+ "17": "horse",
20
+ "18": "sheep",
21
+ "19": "cow",
22
+ "20": "elephant",
23
+ "21": "bear",
24
+ "22": "zebra",
25
+ "23": "giraffe",
26
+ "24": "backpack",
27
+ "25": "umbrella",
28
+ "26": "handbag",
29
+ "27": "tie",
30
+ "28": "suitcase",
31
+ "29": "frisbee",
32
+ "30": "skis",
33
+ "31": "snowboard",
34
+ "32": "sports ball",
35
+ "33": "kite",
36
+ "34": "baseball bat",
37
+ "35": "baseball glove",
38
+ "36": "skateboard",
39
+ "37": "surfboard",
40
+ "38": "tennis racket",
41
+ "39": "bottle",
42
+ "40": "wine glass",
43
+ "41": "cup",
44
+ "42": "fork",
45
+ "43": "knife",
46
+ "44": "spoon",
47
+ "45": "bowl",
48
+ "46": "banana",
49
+ "47": "apple",
50
+ "48": "sandwich",
51
+ "49": "orange",
52
+ "50": "broccoli",
53
+ "51": "carrot",
54
+ "52": "hot dog",
55
+ "53": "pizza",
56
+ "54": "donut",
57
+ "55": "cake",
58
+ "56": "chair",
59
+ "57": "couch",
60
+ "58": "potted plant",
61
+ "59": "bed",
62
+ "60": "dining table",
63
+ "61": "toilet",
64
+ "62": "tv",
65
+ "63": "laptop",
66
+ "64": "mouse",
67
+ "65": "remote",
68
+ "66": "keyboard",
69
+ "67": "cell phone",
70
+ "68": "microwave",
71
+ "69": "oven",
72
+ "70": "toaster",
73
+ "71": "sink",
74
+ "72": "refrigerator",
75
+ "73": "book",
76
+ "74": "clock",
77
+ "75": "vase",
78
+ "76": "scissors",
79
+ "77": "teddy bear",
80
+ "78": "hair drier",
81
+ "79": "toothbrush"
82
+ }
test_images/cat.jpg ADDED
test_yolob8_tflite.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Wed Oct 4 16:44:12 2023
5
+
6
+ @author: lin
7
+ """
8
+ import glob
9
+ import sys
10
+ sys.path.append('../../..')
11
+ import os
12
+ import cv2
13
+ import json
14
+ import tensorflow as tf
15
+ import numpy as np
16
+ import matplotlib.pyplot as plt
17
+ # from utils.bbox_op import non_max_supression
18
+
19
def one_multiple_iou(box, boxes, box_area, boxes_area):
    """
    Compute the intersection over union (IoU) of one box against many boxes.

    Inputs:
        box: numpy array with 1 box, [ymin, xmin, ymax, xmax]
        boxes: numpy array with shape [N, 4] holding N boxes in the same layout
        box_area: area of `box`
        boxes_area: numpy array with shape [N] holding the area of each row
            of `boxes`

    Outputs:
        a float numpy array with shape [N] holding the IoU of `box` with each
        row of `boxes`. Pairs whose union area is not positive (degenerate,
        zero-area boxes) get an IoU of 0 instead of NaN.
    """
    assert boxes.shape[0] == boxes_area.shape[0], "boxes and boxes_area must have the same length"

    # Edges of the intersection rectangle between `box` and every other box.
    ymin = np.maximum(box[0], boxes[:, 0])
    xmin = np.maximum(box[1], boxes[:, 1])
    ymax = np.minimum(box[2], boxes[:, 2])
    xmax = np.minimum(box[3], boxes[:, 3])

    # Boxes that do not overlap would produce a negative side length;
    # clamping each side at 0 makes their intersection area 0.
    intersections = np.maximum(ymax - ymin, 0) * np.maximum(xmax - xmin, 0)

    # Union = both areas added together minus the part counted twice.
    unions = box_area + boxes_area - intersections

    # Divide only where the union is positive so that 0 / 0 (two zero-area
    # boxes) yields 0 rather than NaN and a RuntimeWarning.
    ious = np.zeros(np.shape(unions), dtype=float)
    np.divide(intersections, unions, out=ious, where=unions > 0)
    return ious
48
def select_non_overlapping_bboxes(boxes, scores, iou_th):
    """
    Greedily pick the indexes of the boxes that survive non-max suppression.

    The highest-scoring remaining box is kept, every remaining box whose IoU
    with it is greater than `iou_th` is discarded, and the process repeats
    until no candidates are left.

    Inputs:
        boxes: numpy array with shape [N, 4], rows are [ymin, xmin, ymax, xmax]
        scores: 1-D sequence with N prediction scores
        iou_th: IoU value above which a lower-scoring box is discarded

    Outputs:
        list with the indexes of the kept boxes, ordered from the highest
        score to the lowest.
    """
    ymin = boxes[:, 0]
    xmin = boxes[:, 1]
    ymax = boxes[:, 2]
    xmax = boxes[:, 3]

    # Areas use plain edge differences (no +1), i.e. coordinates are treated
    # as continuous box edges rather than inclusive pixel indexes.
    areas = (ymax - ymin) * (xmax - xmin)

    # Ascending by score, so the best remaining candidate is always last.
    remaining = np.argsort(scores)
    keep_idx = []
    while remaining.size > 0:
        index = remaining[-1]
        remaining = remaining[:-1]
        keep_idx.append(index)
        if remaining.size == 0:
            # Nothing left to compare against.
            break

        ious = one_multiple_iou(
            boxes[index], boxes[remaining], areas[index], areas[remaining]
        )
        # Keep every candidate that is NOT above the threshold; written as a
        # negation so NaN comparisons keep the candidate, as a plain
        # "ious > iou_th" filter would.
        remaining = remaining[~(ious > iou_th)]
    return keep_idx
71
def non_max_supression(boxes, scores, classes, iou_th):
    """
    Remove overlapping bounding boxes.

    Starting from the box with the highest score, any other box whose IoU
    with it is greater than the threshold is removed; otherwise it is kept.
    The class of a box is not taken into account, so overlapping boxes of
    different classes also suppress each other.

    Inputs:
        boxes: numpy array with shape [N, 4] holding N boxes,
            [ymin, xmin, ymax, xmax]
        scores: numpy array with shape [N] holding the prediction score of
            each box
        classes: numpy array with shape [N] holding the class each box
            belongs to
        iou_th: intersection over union threshold above which two
            overlapping boxes are considered to detect the same object

    Output:
        boxes, scores, classes of the kept detections, ordered from the
        highest score to the lowest. Empty inputs are returned unchanged.
    """
    if len(scores) == 0:
        # Nothing to suppress; hand the inputs back untouched.
        return boxes, scores, classes

    keep_idx = select_non_overlapping_bboxes(boxes, scores, iou_th)

    return boxes[keep_idx], scores[keep_idx], classes[keep_idx]
90
+
91
def preprocess(img_path, input_size=640):
    """
    Load an image and turn it into a square, [0, 1]-scaled float array.

    Inputs:
        img_path: path of the image file to read
        input_size: side length, in pixels, of the square output image
            (defaults to 640, the value previously hard-coded)

    Outputs:
        numpy float array with shape [input_size, input_size, channels] and
        values divided by 255.

    Raises:
        ValueError: if the file cannot be read as an image.
    """
    image_np = cv2.imread(img_path)
    if image_np is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError("Could not read image: {}".format(img_path))

    image_np = center_crop(image_np)
    image_np = cv2.resize(image_np, (input_size, input_size))

    # NOTE(review): the channel order is left exactly as cv2.imread returns
    # it; the BGR -> RGB conversion was disabled in the original code.
    # Confirm which order the model (and the matplotlib preview) expects.
    image_np = image_np.astype(float)
    image_np /= 255.0
    return image_np
101
+
102
def center_crop(img):
    """
    Crop the largest centered square out of an image.

    Inputs:
        img: numpy array whose first two dimensions are [height, width]

    Outputs:
        a view of `img` with shape [s, s, ...] where s = min(height, width).
    """
    height, width = img.shape[0], img.shape[1]
    crop_size = min(height, width)

    # Work from the top-left offset instead of "center +/- half size":
    # halving an odd side and doubling it again loses one pixel, and a
    # side of 1 would produce an empty crop.
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2

    return img[top:top + crop_size, left:left + crop_size]
113
+
114
def postprocess_prediction(preds, min_th=None, iou_th=0.5, input_size=640):
    """
    Turn the raw model output into filtered, de-duplicated detections.

    Inputs:
        preds: numpy array indexed as preds[0, k, n]; rows 0-3 of preds[0]
            are used as box center x, center y, width and height, and rows
            4+ as the per-class scores of candidate n.
            NOTE(review): box values are multiplied by `input_size`, so they
            are assumed to be normalized to [0, 1] - confirm against the
            exported model.
        min_th: minimum class score for a candidate to be kept. When None,
            a module-level `min_th` is used if one is defined (this is how
            the value was originally supplied), otherwise 0.1.
        iou_th: IoU threshold passed to the non-max suppression step
        input_size: side length in pixels of the square model input, used to
            scale and clip the boxes

    Outputs:
        bboxes: int numpy array [M, 4] with rows [ymin, xmin, ymax, xmax]
        classes: numpy array [M] with the best class index of each box
        scores: numpy array [M] with the score of that class
        All three are ordered from the highest score to the lowest.
    """
    if min_th is None:
        # Keeps the old behavior when run as a script, while no longer
        # raising NameError when the function is imported and called.
        min_th = globals().get("min_th", 0.1)

    bboxes = preds[0][:4]
    class_prob = preds[0, 4:]
    classes = np.argmax(class_prob, axis=0)
    scores = np.max(class_prob, axis=0)

    # filter by threshold
    valid_idx = np.where(scores >= min_th)[0]
    bboxes = bboxes[:, valid_idx].transpose() * input_size
    classes = classes[valid_idx]
    scores = scores[valid_idx]

    # Center/size -> corner coordinates. True division is used for the half
    # extents; floor division would round every half size down first.
    half_w = bboxes[:, 2] / 2
    half_h = bboxes[:, 3] / 2
    xmin = np.clip(bboxes[:, 0] - half_w, 0, input_size)
    xmax = np.clip(bboxes[:, 0] + half_w, 0, input_size)
    ymin = np.clip(bboxes[:, 1] - half_h, 0, input_size)
    ymax = np.clip(bboxes[:, 1] + half_h, 0, input_size)

    bboxes = np.vstack([ymin, xmin, ymax, xmax]).transpose().astype(int)

    bboxes, scores, classes = non_max_supression(bboxes, scores, classes, iou_th=iou_th)

    # Best detections first.
    idx = np.argsort(scores)[::-1]
    return bboxes[idx], classes[idx], scores[idx]
144
+
145
def plot_prediction(image_np, bboxes, classes, scores, label_map):
    """
    Draw every detection on the image and display the result.

    Inputs:
        image_np: image array the boxes are drawn onto
        bboxes: array whose rows are [ymin, xmin, ymax, xmax]
        classes: class index of each box, same order as `bboxes`
        scores: score of each box (accepted but not used for drawing)
        label_map: mapping from the class index, as a string, to its name
    """
    box_color = (255, 0, 0)
    line_thickness = 5
    text_scale = 3

    for i, (ymin, xmin, ymax, xmax) in enumerate(bboxes):
        image_np = cv2.rectangle(
            image_np,
            (xmin, ymin),
            (xmax, ymax),
            color=box_color,
            thickness=line_thickness,
        )

        # Shift the label away from the image border when the box corner
        # is close to it.
        if xmin > 20:
            text_x = xmin - 10
        else:
            text_x = xmin + 10
        if ymin > 20:
            text_y = ymin - 10
        else:
            text_y = ymin + 10

        class_name = label_map[str(classes[i])]
        cv2.putText(
            image_np,
            class_name,
            (text_x, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            text_scale,
            box_color,
            line_thickness,
        )

    plt.imshow(image_np)
    plt.show()
171
+
172
def predict_yolo_tflite(intenpreter, image_np):
    """
    Run one image through a TFLite interpreter and return its first output.

    Inputs:
        intenpreter: interpreter object with tensors already allocated. The
            misspelled parameter name is kept so existing keyword callers
            keep working.
        image_np: numpy image array without the batch dimension

    Outputs:
        the numpy array held by the interpreter's first output tensor.
    """
    # Use the interpreter that was passed in; previously the argument was
    # ignored and module-level globals were read instead.
    interpreter = intenpreter
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    # Add the batch dimension and cast to float32 before feeding the model.
    input_tensor = np.expand_dims(image_np, axis=0).astype(np.float32)

    interpreter.set_tensor(input_index, input_tensor)
    interpreter.invoke()
    return interpreter.get_tensor(output_index)
181
+
182
if __name__ == "__main__":
    # Manual smoke test: run the TFLite model on every file in `test_images`
    # and show the detections.
    #
    # min_th, interpreter, input_details and output_details are kept as
    # module-level names because other functions in this file may read them
    # as globals.
    min_th = 0.1
    labels_json = "coco_labels.json"
    with open(labels_json, encoding="utf-8") as f:
        label_map = json.load(f)
    img_path = "test_images"
    saved_tflite = "tflite_model.tflite"

    # load model
    interpreter = tf.lite.Interpreter(model_path=saved_tflite)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(input_details)
    print(output_details)

    # Sorted for a reproducible order; sub-directories are skipped because
    # they cannot be read as images.
    images = sorted(
        path
        for path in glob.glob(os.path.join(img_path, "*"))
        if os.path.isfile(path)
    )
    for img in images:
        image_np = preprocess(img)
        print(image_np.shape)

        preds = predict_yolo_tflite(interpreter, image_np)
        bboxes, classes, scores = postprocess_prediction(preds)

        plot_prediction(image_np, bboxes, classes, scores, label_map)