brxerq committed
Commit 0f10612
1 Parent(s): 89046b8

Create app.py

Files changed (1)
app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
+ import os
+ import cv2
+ import numpy as np
+ import importlib.util
+ import gradio as gr
+ from PIL import Image
+
+ # Paths to the TensorFlow Lite model and label map
+ MODEL_DIR = 'model'
+ GRAPH_NAME = 'detect.tflite'
+ LABELMAP_NAME = 'labelmap.txt'
+
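+ # Prefer the lightweight tflite_runtime package when it is available;
+ # otherwise fall back to the interpreter bundled with full TensorFlow.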
+ pkg = importlib.util.find_spec('tflite_runtime')
+ if pkg:
+     from tflite_runtime.interpreter import Interpreter
+     from tflite_runtime.interpreter import load_delegate
+ else:
+     from tensorflow.lite.python.interpreter import Interpreter
+     from tensorflow.lite.python.interpreter import load_delegate
+
+ PATH_TO_CKPT = os.path.join(MODEL_DIR, GRAPH_NAME)
+ PATH_TO_LABELS = os.path.join(MODEL_DIR, LABELMAP_NAME)
+
+ # Load the label map
+ with open(PATH_TO_LABELS, 'r') as f:
+     labels = [line.strip() for line in f.readlines()]
+
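+ # COCO-style label maps use '???' as a placeholder for class id 0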
+ if labels[0] == '???':
+     del labels[0]
+
+ # Load the TensorFlow Lite model
+ interpreter = Interpreter(model_path=PATH_TO_CKPT)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+ height = input_details[0]['shape'][1]
+ width = input_details[0]['shape'][2]
+ floating_model = (input_details[0]['dtype'] == np.float32)
+
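+ # Float models expect inputs normalized to roughly [-1, 1]; quantized
+ # models take raw uint8 pixels, so these values are only applied when
+ # floating_model is True.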
+ input_mean = 127.5
+ input_std = 127.5
+
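+ # TF2-exported SSD models ('StatefulPartitionedCall') order their output
+ # tensors differently from TF1 exports, so pick the indices accordingly.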
+ outname = output_details[0]['name']
+ if 'StatefulPartitionedCall' in outname:
+     boxes_idx, classes_idx, scores_idx = 1, 3, 0
+ else:
+     boxes_idx, classes_idx, scores_idx = 0, 1, 2
+
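+ # Run one inference pass on a BGR frame and draw labelled boxes for every
+ # detection scoring above 0.5. Box coordinates come back normalized to
+ # [0, 1] and are scaled to pixel coordinates here.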
+ def perform_detection(image, interpreter, labels):
+     imH, imW, _ = image.shape
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image_resized = cv2.resize(image_rgb, (width, height))
+     input_data = np.expand_dims(image_resized, axis=0)
+
+     if floating_model:
+         input_data = (np.float32(input_data) - input_mean) / input_std
+
+     interpreter.set_tensor(input_details[0]['index'], input_data)
+     interpreter.invoke()
+
+     boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0]
+     classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0]
+     scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0]
+
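+     # Keep detections above the confidence threshold and clamp each box
+     # to the frame boundaries before drawing.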
+     detections = []
+     for i in range(len(scores)):
+         if (scores[i] > 0.5) and (scores[i] <= 1.0):
+             ymin = int(max(1, (boxes[i][0] * imH)))
+             xmin = int(max(1, (boxes[i][1] * imW)))
+             ymax = int(min(imH, (boxes[i][2] * imH)))
+             xmax = int(min(imW, (boxes[i][3] * imW)))
+
+             cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)
+             object_name = labels[int(classes[i])]
+             label = '%s: %d%%' % (object_name, int(scores[i] * 100))
+             labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
+             label_ymin = max(ymin, labelSize[1] + 10)
+             cv2.rectangle(image, (xmin, label_ymin - labelSize[1] - 10), (xmin + labelSize[0], label_ymin + baseLine - 10), (255, 255, 255), cv2.FILLED)
+             cv2.putText(image, label, (xmin, label_ymin - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
+
+             detections.append([object_name, scores[i], xmin, ymin, xmax, ymax])
+     return image
+
+ def detect_image(input_image):
+     # PIL images are RGB; convert to BGR so perform_detection and the
+     # colors it draws match OpenCV's channel order, then convert back.
+     image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
+     result_image = perform_detection(image, interpreter, labels)
+     return Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
+
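+ # Run detection frame by frame over an uploaded video and re-encode the
+ # annotated frames to a new file that Gradio can serve back.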
+ def detect_video(input_video):
+     # Gradio passes the uploaded video as a file path string
+     cap = cv2.VideoCapture(input_video)
+     frames = []
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         result_frame = perform_detection(frame, interpreter, labels)
+         frames.append(result_frame)
+
+     # Preserve the source frame rate; fall back to 15 fps if it is unknown
+     fps = cap.get(cv2.CAP_PROP_FPS) or 15
+     cap.release()
+
+     if not frames:
+         raise gr.Error("Could not read any frames from the uploaded video.")
+
+     frame_height, frame_width, _ = frames[0].shape
+     size = (frame_width, frame_height)
+     # Write next to the working directory, not by prepending to the full path
+     output_video_path = "result_" + os.path.basename(input_video)
+     out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
+
+     for frame in frames:
+         out.write(frame)
+
+     out.release()
+
+     return output_video_path
+
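+ # Build the Gradio UI: one Interface per media type, combined in tabs below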
+ image_input = gr.Image(type="pil", label="Upload an image")
+ image_output = gr.Image(type="pil", label="Detection Result")
+
+ video_input = gr.Video(label="Upload a video")
+ video_output = gr.Video(label="Detection Result")
+
+ app = gr.Interface(
+     fn=detect_image,
+     inputs=image_input,
+     outputs=image_output,
+     live=True,
+     description="Object Detection on Images"
+ )
+
+ app_video = gr.Interface(
+     fn=detect_video,
+     inputs=video_input,
+     outputs=video_output,
+     live=True,
+     description="Object Detection on Videos"
+ )
+
+ gr.TabbedInterface([app, app_video], ["Image Detection", "Video Detection"]).launch()