brxerq committed on
Commit 2230f78
1 Parent(s): cd3bb2c

Update app.py

Files changed (1)
  1. app.py +82 -182
app.py CHANGED
@@ -1,194 +1,94 @@
 
  import os
- import cv2
  import numpy as np
- import importlib.util
- import gradio as gr
-
- # Function to load the model
- def load_model(modeldir, graph, labels, edgetpu):
-     pkg = importlib.util.find_spec('tflite_runtime')
-     if pkg:
-         from tflite_runtime.interpreter import Interpreter
-         if edgetpu:
-             from tflite_runtime.interpreter import load_delegate
-     else:
-         from tensorflow.lite.python.interpreter import Interpreter
-         if edgetpu:
-             from tensorflow.lite.python.interpreter import load_delegate
-
-     if edgetpu and (graph == 'detect.tflite'):
-         graph = 'edgetpu.tflite'
-
-     PATH_TO_CKPT = os.path.join(modeldir, graph)
-     PATH_TO_LABELS = os.path.join(modeldir, labels)
-
-     with open(PATH_TO_LABELS, 'r') as f:
-         labels = [line.strip() for line in f.readlines()]
-
-     if labels[0] == '???':
-         del(labels[0])
-
-     if edgetpu:
-         interpreter = Interpreter(model_path=PATH_TO_CKPT,
-                                   experimental_delegates=[load_delegate('libedgetpu.so.1.0')])
-     else:
-         interpreter = Interpreter(model_path=PATH_TO_CKPT)
-
-     interpreter.allocate_tensors()
-
-     return interpreter, labels
-
- # Function to detect objects
- def detect_objects(interpreter, labels, input_data, min_conf_threshold):
-     input_details = interpreter.get_input_details()
-     output_details = interpreter.get_output_details()
-     height = input_details[0]['shape'][1]
-     width = input_details[0]['shape'][2]
-
-     floating_model = (input_details[0]['dtype'] == np.float32)
-     input_mean = 127.5
-     input_std = 127.5
-
-     if floating_model:
-         input_data = (np.float32(input_data) - input_mean) / input_std

      interpreter.set_tensor(input_details[0]['index'], input_data)
      interpreter.invoke()
-
-     boxes = interpreter.get_tensor(output_details[0]['index'])[0]
-     classes = interpreter.get_tensor(output_details[1]['index'])[0]
-     scores = interpreter.get_tensor(output_details[2]['index'])[0]
-
-     detections = []
      for i in range(len(scores)):
-         if ((scores[i] > min_conf_threshold) and (scores[i] <= 1.0)):
-             detections.append({
-                 'class': labels[int(classes[i])],
-                 'score': scores[i],
-                 'bbox': boxes[i]
-             })
-     return detections
-
- # Function to process images
- def process_image(image, interpreter, labels, min_conf_threshold):
-     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-     imH, imW, _ = image.shape
-     input_details = interpreter.get_input_details()
-     height = input_details[0]['shape'][1]
-     width = input_details[0]['shape'][2]
-     image_resized = cv2.resize(image_rgb, (width, height))
-     input_data = np.expand_dims(image_resized, axis=0)
-
-     detections = detect_objects(interpreter, labels, input_data, min_conf_threshold)
-
-     for detection in detections:
-         ymin = int(max(1, (detection['bbox'][0] * imH)))
-         xmin = int(max(1, (detection['bbox'][1] * imW)))
-         ymax = int(min(imH, (detection['bbox'][2] * imH)))
-         xmax = int(min(imW, (detection['bbox'][3] * imW)))
-         cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)
-         label = '%s: %d%%' % (detection['class'], int(detection['score']*100))
-         labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
-         label_ymin = max(ymin, labelSize[1] + 10)
-         cv2.rectangle(image, (xmin, label_ymin-labelSize[1]-10), (xmin+labelSize[0], label_ymin+baseLine-10), (255, 255, 255), cv2.FILLED)
-         cv2.putText(image, label, (xmin, label_ymin-7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
-
-     return image
-
- # Function to process videos
- def process_video(video_path, interpreter, labels, min_conf_threshold):
-     video = cv2.VideoCapture(video_path)
-     imW = video.get(cv2.CAP_PROP_FRAME_WIDTH)
-     imH = video.get(cv2.CAP_PROP_FRAME_HEIGHT)
-     output_path = "output_" + os.path.basename(video_path)
-     out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 20, (int(imW), int(imH)))
-
-     while video.isOpened():
-         ret, frame = video.read()
-         if not ret:
-             break
-
-         image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-         input_details = interpreter.get_input_details()
-         height = input_details[0]['shape'][1]
-         width = input_details[0]['shape'][2]
-         image_resized = cv2.resize(image_rgb, (width, height))
-         input_data = np.expand_dims(image_resized, axis=0)
-
-         detections = detect_objects(interpreter, labels, input_data, min_conf_threshold)
-
-         for detection in detections:
-             ymin = int(max(1, (detection['bbox'][0] * imH)))
-             xmin = int(max(1, (detection['bbox'][1] * imW)))
-             ymax = int(min(imH, (detection['bbox'][2] * imH)))
-             xmax = int(min(imW, (detection['bbox'][3] * imW)))
-             cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)
-             label = '%s: %d%%' % (detection['class'], int(detection['score']*100))
-             labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
-             label_ymin = max(ymin, labelSize[1] + 10)
-             cv2.rectangle(frame, (xmin, label_ymin-labelSize[1]-10), (xmin+labelSize[0], label_ymin+baseLine-10), (255, 255, 255), cv2.FILLED)
-             cv2.putText(frame, label, (xmin, label_ymin-7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
-
-         out.write(frame)
-
-     video.release()
-     out.release()
-     return output_path
-
- # Gradio interface
- def predict_image(image, modeldir, graph, labels, threshold, edgetpu):
-     interpreter, labels = load_model(modeldir, graph, labels, edgetpu)
-     min_conf_threshold = float(threshold)
-     result_image = process_image(image, interpreter, labels, min_conf_threshold)
-     return result_image
-
- def predict_video(video, modeldir, graph, labels, threshold, edgetpu):
-     video_path = "temp_video.mp4"
-     with open(video_path, "wb") as f:
-         f.write(video.read())
-     interpreter, labels = load_model(modeldir, graph, labels, edgetpu)
-     min_conf_threshold = float(threshold)
-     output_path = process_video(video_path, interpreter, labels, min_conf_threshold)
-     with open(output_path, "rb") as f:
-         return f.read()
-
- iface = gr.Blocks()
-
- with iface:
-     gr.Markdown("# Object Detection")
-     gr.Markdown("Upload an image or a video to detect objects using a TFLite model.")
-
-     with gr.Tabs():
-         with gr.TabItem("Image Detection"):
-             img_input = gr.Image(type="numpy", label="Upload an Image")
-             model_dir = gr.Textbox(label="Model Directory", value="model/")
-             graph_name = gr.Textbox(label="Graph Name", value="detect.tflite")
-             labels_name = gr.Textbox(label="Labels Name", value="labelmap.txt")
-             threshold = gr.Slider(label="Confidence Threshold", minimum=0, maximum=1, value=0.5)
-             edgetpu = gr.Checkbox(label="Use Edge TPU")
-             img_output = gr.Image(type="numpy", label="Detected Image")
-             img_submit = gr.Button("Submit")
-             img_submit.click(
-                 predict_image,
-                 inputs=[img_input, model_dir, graph_name, labels_name, threshold, edgetpu],
-                 outputs=img_output,
-                 show_progress=True
-             )
-
-         with gr.TabItem("Video Detection"):
-             video_input = gr.Video(type="file", label="Upload a Video")
-             model_dir = gr.Textbox(label="Model Directory", value="model/")
-             graph_name = gr.Textbox(label="Graph Name", value="detect.tflite")
-             labels_name = gr.Textbox(label="Labels Name", value="labelmap.txt")
-             threshold = gr.Slider(label="Confidence Threshold", minimum=0, maximum=1, value=0.5)
-             edgetpu = gr.Checkbox(label="Use Edge TPU")
-             video_output = gr.Video(label="Detected Video")
-             video_submit = gr.Button("Submit")
-             video_submit.click(
-                 predict_video,
-                 inputs=[video_input, model_dir, graph_name, labels_name, threshold, edgetpu],
-                 outputs=video_output,
-                 show_progress=True
-             )
-
- iface.launch()
 
+ import streamlit as st
  import os
  import numpy as np
+ import cv2
+ from PIL import Image
+ import tempfile
+
+ # Setup the model and labels
+ MODEL_NAME = 'model'
+ GRAPH_NAME = 'detect.tflite'
+ LABELMAP_NAME = 'labelmap.txt'
+ min_conf_threshold = 0.5
+ use_TPU = False  # Change this based on your needs; must be set before the import check below
+
+ # TensorFlow imports
+ from tensorflow.lite.python.interpreter import Interpreter
+ if use_TPU:
+     from tensorflow.lite.python.interpreter import load_delegate
+
+ PATH_TO_CKPT = os.path.join(MODEL_NAME, GRAPH_NAME)
+ PATH_TO_LABELS = os.path.join(MODEL_NAME, LABELMAP_NAME)
+
+ # Load labels
+ with open(PATH_TO_LABELS, 'r') as f:
+     labels = [line.strip() for line in f.readlines()]
+ if labels[0] == '???':
+     del(labels[0])
+
+ # Load model
+ interpreter = Interpreter(model_path=PATH_TO_CKPT)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+ height = input_details[0]['shape'][1]
+ width = input_details[0]['shape'][2]
+
+ # Streamlit interface
+ st.title('Object Detection System')
+ st.sidebar.title('Settings')
+ uploaded_file = st.sidebar.file_uploader("Choose an image or video file", type=['jpg', 'png', 'jpeg', 'mp4'])
+
+ def detect_objects(image):
+     # Prepare image for detection
+     imH, imW, _ = image.shape  # source dimensions, needed to scale the normalized boxes
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image_resized = cv2.resize(image_rgb, (width, height))
+     input_data = np.expand_dims(image_resized, axis=0)
+     if input_details[0]['dtype'] == np.float32:
+         input_data = (np.float32(input_data) - 127.5) / 127.5  # Normalize float models only
+
+     # Perform detection
      interpreter.set_tensor(input_details[0]['index'], input_data)
      interpreter.invoke()
+
+     # Retrieve detection results
+     boxes = interpreter.get_tensor(output_details[0]['index'])[0]  # Bounding box coordinates of detected objects
+     classes = interpreter.get_tensor(output_details[1]['index'])[0]  # Class index of detected objects
+     scores = interpreter.get_tensor(output_details[2]['index'])[0]  # Confidence of detected objects
+
      for i in range(len(scores)):
+         if scores[i] > min_conf_threshold and scores[i] <= 1.0:
+             # Draw bounding boxes and labels on the image
+             ymin, xmin, ymax, xmax = boxes[i]
+             (left, right, top, bottom) = (xmin * imW, xmax * imW, ymin * imH, ymax * imH)
+             cv2.rectangle(image, (int(left), int(top)), (int(right), int(bottom)), (10, 255, 0), 4)
+             object_name = labels[int(classes[i])]
+             label = '%s: %d%%' % (object_name, int(scores[i]*100))
+             cv2.putText(image, label, (int(left), int(top)-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
+     return image
+
+ if uploaded_file is not None:
+     file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
+     if uploaded_file.type == "video/mp4":
+         # Handle video upload; the stream was already consumed above, so write file_bytes
+         tfile = tempfile.NamedTemporaryFile(delete=False)
+         tfile.write(file_bytes.tobytes())
+         tfile.close()
+
+         cap = cv2.VideoCapture(tfile.name)
+
+         stframe = st.empty()
+
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             frame = detect_objects(frame)
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV frames are BGR; Streamlit expects RGB
+             stframe.image(frame)
+         cap.release()
+     else:
+         # Handle image upload
+         image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
+         image = detect_objects(image)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # convert BGR to RGB for display
+         st.image(image, use_column_width=True)
94