File size: 3,802 Bytes
858c475
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import streamlit as st
import os
from PIL import Image
import torch
import cv2
from pathlib import Path
from detect_dual import run as yolo_run_detection

# KMP_DUPLICATE_LIB_OK is read by the Intel OpenMP runtime; presumably set here
# so that a second copy of the runtime (e.g. bundled by torch and cv2) does not
# abort the process.
# NOTE(review): this runs after torch/cv2 are imported; if the runtime has
# already initialised by then, the setting may come too late — confirm.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "True"})

def add_logo(logo_path, size=(200, 150)):
    """Display the logo image at a fixed pixel size.

    Args:
        logo_path: Path of the image file to show.
        size: Target ``(width, height)`` in pixels. The image is resized to
            exactly this size, so its aspect ratio is not preserved.
    """
    # Open inside ``with`` so the file handle is closed as soon as the resized
    # copy exists; ``resize`` returns a new image independent of the file.
    with Image.open(logo_path) as logo:
        resized_logo = logo.resize(size)
    st.image(resized_logo, use_column_width=False)

def run_detection(
    source_path,
    *,
    weights="models/detect/yolo9trGPR.pt",
    output_dir=Path("runs/detect/exp"),
    conf_thres=0.25,
    iou_thres=0.45,
):
    """Run YOLO detection on one source and return the annotated output path.

    Wraps ``detect_dual.run`` with this app's fixed inference settings. The
    keyword-only arguments default to the values that used to be hard-coded,
    so existing ``run_detection(path)`` calls behave as before.

    Args:
        source_path: Path of the file to run detection on.
        weights: Path to the model weights file.
        output_dir: Directory results are written to (str or ``Path``). It is
            passed to ``detect_dual.run`` as ``project=<parent>`` and
            ``name=<final component>`` with ``exist_ok=True``, so the same
            directory is reused on every call.
        conf_thres: Confidence threshold forwarded to ``detect_dual.run``.
        iou_thres: IoU threshold forwarded to ``detect_dual.run``.

    Returns:
        str: ``output_dir / <basename of source_path>``. This assumes
        ``detect_dual`` saves the annotated result under the source's file
        name — confirm against ``detect_dual``.
    """
    output_dir = Path(output_dir)
    yolo_run_detection(
        weights=weights,
        source=source_path,
        imgsz=(640, 640),
        conf_thres=conf_thres,
        iou_thres=iou_thres,
        max_det=1000,
        device='',
        view_img=False,
        save_txt=False,
        save_conf=False,
        save_crop=False,
        nosave=False,
        classes=None,
        agnostic_nms=False,
        augment=False,
        visualize=False,
        update=False,
        project=output_dir.parent,
        name=output_dir.name,
        exist_ok=True,
        line_thickness=3,
        hide_labels=False,
        hide_conf=False,
        half=False,
        dnn=False,
        vid_stride=1,
    )
    return str(output_dir / Path(source_path).name)

def process_video(video_path, stride=10):
    """Run detection on every ``stride``-th frame of a video.

    Frames are read sequentially until the capture reports no more frames,
    rather than seeking by ``CAP_PROP_FRAME_COUNT``. That property can be 0 or
    negative for some containers, and per-frame seeking can be imprecise.

    Args:
        video_path: Path of the video file to open with OpenCV.
        stride: Process one frame out of every ``stride``. With the default
            of 10, the frames processed are 0, 10, 20, ...

    Yields:
        str: The path returned by ``run_detection`` for each processed frame.

    Raises:
        ValueError: If ``stride`` is less than 1 or the video cannot be
            opened. Because this is a generator, these are raised on the
            first iteration, not at call time.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Unable to open video file: {video_path}")

        frame_index = 0
        while True:
            if frame_index % stride == 0:
                ok, frame = cap.read()
                if not ok:
                    break
                frame_path = f"temp_frame_{frame_index}.jpg"
                cv2.imwrite(frame_path, frame)
                try:
                    yield run_detection(frame_path)
                finally:
                    # Clean up even if detection raises or the consumer stops
                    # iterating early (generator close() lands here).
                    if os.path.exists(frame_path):
                        os.remove(frame_path)
            elif not cap.grab():
                # Skipped frame: advance without decoding into an array.
                break
            frame_index += 1
    finally:
        # Release the capture on exhaustion, error, or early close.
        cap.release()

def _image_section():
    """Render the image workflow: choose an image, preview it, run detection."""
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        source_path = "temp_image.jpg"
        with open(source_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
    else:
        source_path = "GPR_example.jpg"  # Default image

    st.image(source_path, caption="Image for Detection", use_column_width=True)

    if st.button("Run Detection"):
        # Report failures in the page, consistent with the video workflow.
        try:
            with st.spinner("Running detection..."):
                output_path = run_detection(source_path)
                st.image(output_path, caption="Detection Result", use_column_width=True)
        except Exception as e:
            st.error(f"An error occurred: {e}")


def _video_section():
    """Render the video workflow: upload a video and show per-frame results."""
    uploaded_file = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"])
    if uploaded_file is None or not st.button("Run Detection"):
        return

    source_path = "temp_video.mp4"
    # Write the temp video only when detection actually runs, so the finally
    # clause always removes it (it is never left behind by a plain rerun).
    try:
        with open(source_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        with st.spinner("Running detection..."):
            result_placeholder = st.empty()
            for frame in process_video(source_path):
                result_placeholder.image(frame, caption="Detection Result", use_column_width=True)
    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        if os.path.exists(source_path):
            os.remove(source_path)  # Clean up temporary video file


def main():
    """Build the Streamlit page: title, logo, source selector and workflow."""
    st.title("YOLO9tr GPR detection")

    add_logo("logo_ai.jpg")

    source_type = st.radio("Select source type:", ("Image", "Video"))

    if source_type == "Image":
        _image_section()
    elif source_type == "Video":
        _video_section()


if __name__ == "__main__":
    main()