|
import streamlit as st |
|
import os |
|
from PIL import Image |
|
import torch |
|
import cv2 |
|
from pathlib import Path |
|
from detect_dual import run as yolo_run_detection |
|
|
|
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' |
|
|
|
def add_logo(logo_path, size=(200, 150)): |
|
logo = Image.open(logo_path) |
|
logo = logo.resize(size) |
|
st.image(logo, use_column_width=False) |
|
|
|
def run_detection(source_path): |
|
output_dir = Path("runs/detect/exp") |
|
yolo_run_detection( |
|
weights="models/detect/yolo9trGPR.pt", |
|
source=source_path, |
|
imgsz=(640, 640), |
|
conf_thres=0.25, |
|
iou_thres=0.45, |
|
max_det=1000, |
|
device='', |
|
view_img=False, |
|
save_txt=False, |
|
save_conf=False, |
|
save_crop=False, |
|
nosave=False, |
|
classes=None, |
|
agnostic_nms=False, |
|
augment=False, |
|
visualize=False, |
|
update=False, |
|
project=output_dir.parent, |
|
name=output_dir.name, |
|
exist_ok=True, |
|
line_thickness=3, |
|
hide_labels=False, |
|
hide_conf=False, |
|
half=False, |
|
dnn=False, |
|
vid_stride=1, |
|
) |
|
output_path = output_dir / Path(source_path).name |
|
return str(output_path) |
|
|
|
def process_video(video_path): |
|
cap = cv2.VideoCapture(video_path) |
|
if not cap.isOpened(): |
|
raise ValueError(f"Unable to open video file: {video_path}") |
|
|
|
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
|
|
|
for i in range(0, frame_count, 10): |
|
cap.set(cv2.CAP_PROP_POS_FRAMES, i) |
|
ret, frame = cap.read() |
|
if ret: |
|
frame_path = f"temp_frame_{i}.jpg" |
|
cv2.imwrite(frame_path, frame) |
|
output_frame = run_detection(frame_path) |
|
yield output_frame |
|
os.remove(frame_path) |
|
else: |
|
break |
|
|
|
cap.release() |
|
|
|
def main(): |
|
st.title("YOLO9tr GPR detection") |
|
|
|
add_logo("logo_ai.jpg") |
|
|
|
source_type = st.radio("Select source type:", ("Image", "Video")) |
|
|
|
if source_type == "Image": |
|
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) |
|
if uploaded_file is not None: |
|
source_path = "temp_image.jpg" |
|
with open(source_path, "wb") as f: |
|
f.write(uploaded_file.getbuffer()) |
|
else: |
|
source_path = "GPR_example.jpg" |
|
|
|
st.image(source_path, caption="Image for Detection", use_column_width=True) |
|
|
|
if st.button("Run Detection"): |
|
with st.spinner("Running detection..."): |
|
output_path = run_detection(source_path) |
|
st.image(output_path, caption="Detection Result", use_column_width=True) |
|
|
|
elif source_type == "Video": |
|
uploaded_file = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"]) |
|
if uploaded_file is not None: |
|
source_path = "temp_video.mp4" |
|
with open(source_path, "wb") as f: |
|
f.write(uploaded_file.getbuffer()) |
|
|
|
if st.button("Run Detection"): |
|
try: |
|
with st.spinner("Running detection..."): |
|
output_frames = process_video(source_path) |
|
result_placeholder = st.empty() |
|
for frame in output_frames: |
|
result_placeholder.image(frame, caption="Detection Result", use_column_width=True) |
|
except Exception as e: |
|
st.error(f"An error occurred: {str(e)}") |
|
finally: |
|
if os.path.exists(source_path): |
|
os.remove(source_path) |
|
|
|
if __name__ == "__main__": |
|
main() |