gghsgn commited on
Commit
ecca034
1 Parent(s): 7ef831c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +261 -0
app.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import semua library yang dibutuhkan
2
+ import os
3
+ import io
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image, ImageDraw
7
+ from transformers import AutoImageProcessor, AutoModelForObjectDetection
8
+ import streamlit as st
9
+ import torch
10
+ import time
11
+ import pandas as pd
12
+
13
+ # Setting page layout
14
+ st.set_page_config(
15
+ page_title="YoloS Helmet Detection",
16
+ page_icon="🤗",
17
+ layout="wide",
18
+ initial_sidebar_state="expanded"
19
+ )
20
+
21
+ def input_image_setup(uploaded_file):
22
+ if uploaded_file is not None:
23
+ bytes_data = uploaded_file.getvalue()
24
+ image = Image.open(io.BytesIO(bytes_data)) # Convert bytes data to PIL image
25
+ return image
26
+ else:
27
+ raise FileNotFoundError("No file uploaded")
28
+
29
+ # Function to convert OpenCV image to PIL image
30
+ def cv2_to_pil(image):
31
+ return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
32
+
33
+ def draw_bounding_boxes(image, results, model, confidence):
34
+ draw = ImageDraw.Draw(image)
35
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
36
+ if score.item() >= confidence:
37
+ box = [int(i) for i in box.tolist()]
38
+ draw.rectangle(box, outline="purple", width=2)
39
+ label_text = f"{model.config.id2label[label.item()]} ({round(score.item(), 2)})"
40
+ draw.text((box[0], box[1]), label_text, fill="white")
41
+ return image
42
+
43
+ def process_image(image, model, processor, confidence):
44
+ inputs = processor(images=image, return_tensors="pt")
45
+ outputs = model(**inputs)
46
+
47
+ target_sizes = torch.tensor([image.size[::-1]])
48
+ results = processor.post_process_object_detection(outputs, threshold=confidence, target_sizes=target_sizes)[0]
49
+ return results
50
+
51
+ def detection_results_to_dict(results, model):
52
+ detection_dict = {
53
+ "objects": []
54
+ }
55
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
56
+ box = [round(i, 2) for i in box.tolist()]
57
+ detection_dict["objects"].append({
58
+ "label": model.config.id2label[label.item()],
59
+ "confidence": round(score.item(), 3),
60
+ "box": box
61
+ })
62
+ return detection_dict
63
+
64
+ def convert_dict_to_csv(detection_dict_list):
65
+ combined_results = []
66
+ for detection_dict in detection_dict_list:
67
+ combined_results.extend(detection_dict["objects"])
68
+ df = pd.DataFrame(combined_results)
69
+ return df.to_csv(index=False).encode('utf-8')
70
+
71
+ def clear_detection_results():
72
+ st.session_state.detection_dict_list = []
73
+
74
+ # Initialize session state to store detection results
75
+ if 'detection_dict_list' not in st.session_state:
76
+ st.session_state.detection_dict_list = []
77
+
78
+ # Streamlit App Configuration
79
+ st.header("Helmet Rider Detection")
80
+
81
+ # Sidebar for Model Selection and Confidence Slider
82
+ st.sidebar.header("ML Model Config")
83
+ models = ["gghsgn/final200" ,"gghsgn/final100", "/gghsgn/final50"]
84
+
85
+ model_name = st.sidebar.selectbox("Select model", models)
86
+ confidence = st.sidebar.slider("Select Model Confidence", 25, 100, 40, step=5) / 100
87
+
88
+ # Load Model and Processor
89
+ processor = AutoImageProcessor.from_pretrained(model_name)
90
+ model = AutoModelForObjectDetection.from_pretrained(model_name)
91
+
92
+ # Option to select Real-Time or Upload Image
93
+ mode = st.sidebar.selectbox("Select Input Mode", ["Upload Image", "Real-Time Webcam", "RTSP Video"])
94
+
95
+ # Option if select Upload Image
96
+ if mode == "Upload Image":
97
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
98
+ if uploaded_file is not None:
99
+ image = Image.open(uploaded_file)
100
+ st.image(image, caption="Uploaded Image.", use_column_width=True)
101
+ else:
102
+ image = None
103
+ submit = st.button("Detect Objects")
104
+ if submit and image is not None:
105
+ try:
106
+ image_data = input_image_setup(uploaded_file)
107
+ st.subheader("The response is..")
108
+
109
+ results = process_image(image, model, processor, confidence)
110
+ drawn_image = draw_bounding_boxes(image.copy(), results, model, confidence)
111
+ st.image(drawn_image, caption="Detected Objects", use_column_width=True)
112
+
113
+ st.subheader("List of Objects:")
114
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
115
+ box = [round(i, 2) for i in box.tolist()]
116
+ st.write(
117
+ f"Detected :orange[{model.config.id2label[label.item()]}] with confidence "
118
+ f":green[{round(score.item(), 3)}] at location :violet[{box}]"
119
+ )
120
+
121
+ detected_objects = {model.config.id2label[label.item()]: 0 for label in results["labels"]}
122
+ for label in results["labels"]:
123
+ detected_objects[model.config.id2label[label.item()]] += 1
124
+ for obj, count in detected_objects.items():
125
+ st.write(f"Class :orange[{obj}] detected {count} time(s)")
126
+
127
+ detection_dict = detection_results_to_dict(results, model)
128
+ #st.write(detection_dict)
129
+
130
+ csv_data = convert_dict_to_csv([detection_dict])
131
+ st.download_button(
132
+ label="Download Results as CSV",
133
+ data=csv_data,
134
+ file_name="detection_results.csv",
135
+ mime="text/csv"
136
+ )
137
+
138
+ except Exception as e:
139
+ st.error(f"Error: {e}")
140
+
141
+ elif submit and image is None:
142
+ st.error("Please upload an image before trying to detect objects.")
143
+
144
+ # Option if select Realtime Webcam
145
+ elif mode == "Real-Time Webcam":
146
+ run = st.checkbox("Run Webcam")
147
+ FRAME_WINDOW = st.image([])
148
+
149
+ if run:
150
+ cap = cv2.VideoCapture(0)
151
+ while run:
152
+ ret, frame = cap.read()
153
+ if not ret:
154
+ st.error("Failed to capture image from webcam")
155
+ break
156
+
157
+ frame_pil = cv2_to_pil(frame)
158
+
159
+ try:
160
+ results = process_image(frame_pil, model, processor, confidence)
161
+ drawn_image = draw_bounding_boxes(frame_pil.copy(), results, model, confidence)
162
+ FRAME_WINDOW.image(drawn_image, caption="Detected Objects", use_column_width=True)
163
+
164
+ st.subheader("List of Objects:")
165
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
166
+ box = [round(i, 2) for i in box.tolist()]
167
+ st.write(
168
+ f"Detected :orange[{model.config.id2label[label.item()]}] with confidence "
169
+ f":green[{round(score.item(), 3)}] at location :violet[{box}]"
170
+ )
171
+
172
+ detected_objects = {model.config.id2label[label.item()]: 0 for label in results["labels"]}
173
+ for label in results["labels"]:
174
+ detected_objects[model.config.id2label[label.item()]] += 1
175
+ for obj, count in detected_objects.items():
176
+ st.write(f"Class :orange[{obj}] detected {count} time(s)")
177
+
178
+ detection_dict = detection_results_to_dict(results, model)
179
+ st.session_state.detection_dict_list.append(detection_dict)
180
+ #st.write(detection_dict)
181
+
182
+
183
+ except Exception as e:
184
+ st.error(f"Error: {e}")
185
+
186
+ time.sleep(0.1) # Delay for the next frame capture to create an illusion of real-time
187
+ cap.release()
188
+
189
+ if not run and st.session_state.detection_dict_list:
190
+ st.write("Detection stopped.")
191
+ csv_data = convert_dict_to_csv(st.session_state.detection_dict_list)
192
+ st.download_button(
193
+ label="Download All Results as CSV",
194
+ data=csv_data,
195
+ file_name="all_detection_results.csv",
196
+ mime="text/csv"
197
+ )
198
+ st.button("Clear Results", on_click=clear_detection_results)
199
+
200
+ # Option if select Realtime RTSP Video
201
+ elif mode == "RTSP Video":
202
+ rtsp_url = st.text_input("RTSP URL")
203
+ run = st.checkbox("Run RTSP Video")
204
+ FRAME_WINDOW = st.image([])
205
+
206
+ if rtsp_url and run:
207
+ cap = cv2.VideoCapture(rtsp_url)
208
+ st.subheader("List of Objects:")
209
+ while run:
210
+ ret, frame = cap.read()
211
+ if not ret:
212
+ st.error("Failed to capture image from RTSP stream")
213
+ break
214
+
215
+ frame_pil = cv2_to_pil(frame)
216
+
217
+ try:
218
+ results = process_image(frame_pil, model, processor, confidence)
219
+ drawn_image = draw_bounding_boxes(frame_pil.copy(), results, model, confidence)
220
+ FRAME_WINDOW.image(drawn_image, caption="Detected Objects", use_column_width=True)
221
+ st.subheader("List of Objects:")
222
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
223
+ box = [round(i, 2) for i in box.tolist()]
224
+ st.write(
225
+ f"Detected :orange[{model.config.id2label[label.item()]}] with confidence "
226
+ f":green[{round(score.item(), 3)}] at location :violet[{box}]"
227
+ )
228
+
229
+ detected_objects = {model.config.id2label[label.item()]: 0 for label in results["labels"]}
230
+ for label in results["labels"]:
231
+ detected_objects[model.config.id2label[label.item()]] += 1
232
+ for obj, count in detected_objects.items():
233
+ st.write(f"Class :orange[{obj}] detected {count} time(s)")
234
+
235
+ detection_dict = detection_results_to_dict(results, model)
236
+ st.session_state.detection_dict_list.append(detection_dict)
237
+ #st.write(detection_dict)
238
+
239
+ except Exception as e:
240
+ st.error(f"Error: {e}")
241
+
242
+ time.sleep(0.1) # Delay for the next frame capture to create an illusion of real-time
243
+
244
+ cap.release()
245
+
246
+ if not run and st.session_state.detection_dict_list:
247
+ st.write("Detection stopped.")
248
+ csv_data = convert_dict_to_csv(st.session_state.detection_dict_list)
249
+ st.download_button(
250
+ label="Download All Results as CSV",
251
+ data=csv_data,
252
+ file_name="all_detection_results.csv",
253
+ mime="text/csv"
254
+ )
255
+ st.button("Clear Results", on_click=clear_detection_results)
256
+ elif run:
257
+ st.error("Please provide a valid RTSP URL before running the stream.")
258
+
259
+ # Ensure the video capture object is released if the checkbox is unchecked
260
+ if 'cap' in locals():
261
+ cap.release()