File size: 12,761 Bytes
9de653a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import streamlit as st
import cv2
import torch
import numpy as np
import time
import tempfile
from pathlib import Path

# Import detection utilities
from detection_utils import load_model, detect_objects, draw_boxes, ObjectTracker

def initialize_video_capture(input_source, video_file=None, url=None):
    """Open the selected input and, for uploaded files, a matching video writer.

    Args:
        input_source: Label chosen in the input radio, ``"Video File"`` or
            ``"Live Stream URL"``; any other value opens nothing.
        video_file: Uploaded file-like object, used for ``"Video File"``. Its
            bytes are copied to a temporary file so OpenCV can open it by path.
        url: Stream URL passed to ``cv2.VideoCapture``, used for
            ``"Live Stream URL"``.

    Returns:
        tuple: ``(cap, out, output_path)``.

        * Stream: ``(cap, None, None)`` -- nothing is recorded.
        * Video file: ``(cap, out, output_path)`` where ``out`` is an opened
          ``cv2.VideoWriter``; if ``cap`` itself could not be opened the result
          is ``(cap, None, None)``.
        * ``(None, None, None)`` when no input was supplied, or when no codec
          could be opened for the output (an error is shown and ``cap`` is
          released in that case).

        A returned ``cap`` may be unopened; callers must check ``cap.isOpened()``.
    """
    if input_source == "Live Stream URL":
        return (cv2.VideoCapture(url), None, None) if url else (None, None, None)
    if input_source != "Video File" or video_file is None:
        return None, None, None

    # Keep the upload's own extension (e.g. .avi) so the backend sees a matching name.
    suffix = Path(getattr(video_file, 'name', '') or '').suffix.lower() or '.mp4'
    # getvalue() returns the whole upload even if the stream was already read once.
    data = video_file.getvalue() if hasattr(video_file, 'getvalue') else video_file.read()
    # Close our handle before OpenCV opens the path; otherwise the handle leaks.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
        tfile.write(data)
        video_path = tfile.name

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return cap, None, None

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates such as 29.97; truncating would change playback speed.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # also catches NaN / unknown rates
        fps = 30.0

    # NOTE(review): the output name is fixed, so concurrent sessions share one file.
    temp_dir = Path(tempfile.gettempdir())
    # Codecs in order of preference; the last one falls back to an AVI container.
    for codec, ext in (('avc1', '.mp4'), ('mp4v', '.mp4'), ('XVID', '.avi')):
        output_path = str(temp_dir / f'detected_output{ext}')
        try:
            out = cv2.VideoWriter(
                output_path,
                cv2.VideoWriter_fourcc(*codec),
                fps,
                (width, height),
                isColor=True,
            )
        except cv2.error:
            continue
        if out.isOpened():
            return cap, out, output_path
        # Free the failed writer before trying the next codec.
        out.release()

    st.error("Failed to create video writer")
    cap.release()
    return None, None, None

def get_model_info():
    """Describe every YOLOv8 checkpoint offered in the sidebar.

    Returns:
        dict: A fresh mapping from weights filename (e.g. ``'yolov8n.pt'``) to a
        dict with the keys ``name``, ``description``, ``speed``, ``accuracy``,
        ``size`` and ``details``. Entries are ordered from the smallest/fastest
        model to the largest/most accurate one; ``speed`` and ``accuracy`` are
        emoji ratings meant for display.
    """
    columns = ('name', 'description', 'speed', 'accuracy', 'size', 'details')
    catalogue = (
        ('yolov8n.pt', 'YOLOv8 Nano',
         'Smallest and fastest model. Best for CPU or low-power devices.',
         '⚡⚡⚡⚡⚡', '⭐⭐', '6.7 MB',
         'Ideal for real-time applications with limited computing power.'),
        ('yolov8s.pt', 'YOLOv8 Small',
         'Small model balancing speed and accuracy.',
         '⚡⚡⚡⚡', '⭐⭐⭐', '22.4 MB',
         'Good for general purpose detection with decent performance.'),
        ('yolov8m.pt', 'YOLOv8 Medium',
         'Medium-sized model with good balance.',
         '⚡⚡⚡', '⭐⭐⭐⭐', '52.2 MB',
         'Recommended for standard detection tasks with good GPU.'),
        ('yolov8l.pt', 'YOLOv8 Large',
         'Large model with high accuracy.',
         '⚡⚡', '⭐⭐⭐⭐⭐', '87.7 MB',
         'Best for high-accuracy requirements with good computing power.'),
        ('yolov8x.pt', 'YOLOv8 XLarge',
         'Extra large model with highest accuracy.',
         '⚡', '⭐⭐⭐⭐⭐⭐', '131.7 MB',
         'Best for tasks requiring maximum accuracy, requires powerful GPU.'),
    )
    # Pair each row's values with the column names, keyed by the weights file.
    return {weights: dict(zip(columns, row)) for weights, *row in catalogue}

def _init_session_state():
    """Seed ``st.session_state`` with the keys this app uses, keeping existing values."""
    state = st.session_state
    if 'tracker' not in state:
        # Created once so the tracker survives Streamlit reruns.
        state.tracker = ObjectTracker()
    defaults = {
        'cap': None,               # capture of the active run, if any
        'out': None,               # writer of the active run, if any
        'output_path': None,       # file written by the last run (video files only)
        'processed_frames': 0,     # frames written to `out` by the last run
        'selected_model': 'yolov8x.pt',
        'model': None,
        'run_detection': False,
    }
    for key, value in defaults.items():
        if key not in state:
            state[key] = value


def _render_model_selector():
    """Draw the model picker, its details panel and the load button in the sidebar."""
    st.sidebar.subheader("Model Selection")
    model_info = get_model_info()
    model_keys = list(model_info)
    selected_model = st.sidebar.selectbox(
        "Choose YOLO Model",
        options=model_keys,
        format_func=lambda x: model_info[x]['name'],
        index=model_keys.index(st.session_state.selected_model),
    )

    info = model_info[selected_model]
    with st.sidebar.expander("Model Details", expanded=True):
        st.markdown(f"**{info['name']}**")
        st.write(info['description'])
        st.write(f"Speed: {info['speed']}")
        st.write(f"Accuracy: {info['accuracy']}")
        st.write(f"Size: {info['size']}")
        st.write(f"Details: {info['details']}")

    if st.sidebar.button("Load Selected Model"):
        with st.spinner(f"Loading {info['name']}..."):
            st.session_state.model = load_model(selected_model)
            st.session_state.selected_model = selected_model
            st.sidebar.success("Model loaded successfully!")


def _process_stream(video_placeholder, detection_confidence):
    """Detect objects on each frame of the open capture; write and preview results."""
    state = st.session_state
    frame_index = 0
    while state.run_detection:
        ret, frame = state.cap.read()
        if not ret:
            # Source exhausted: stop, so later reruns don't restart processing
            # and overwrite the output file offered for download.
            state.run_detection = False
            break

        detections = detect_objects(state.model, frame, detection_confidence)
        annotated_frame = draw_boxes(frame, detections, state.tracker)

        # VideoWriter buffers internally, so frames are written straight away.
        if state.out is not None:
            state.out.write(annotated_frame)
            state.processed_frames += 1

        # Refresh the preview only on every 3rd frame to keep the UI responsive.
        if frame_index % 3 == 0:
            video_placeholder.image(annotated_frame, channels="BGR")
        frame_index += 1

        # Minimal sleep to prevent UI freezing
        time.sleep(0.001)


def _release_resources():
    """Release the capture and finalize the writer (completes the output file)."""
    state = st.session_state
    if state.cap is not None:
        state.cap.release()
        state.cap = None
    if state.out is not None:
        state.out.release()
        state.out = None


def _render_download_section():
    """Offer the last run's output file for download if any frames were written."""
    state = st.session_state
    if state.processed_frames <= 0:
        return
    st.subheader("Download Processed Video")

    output_path = state.output_path
    if not output_path or not Path(output_path).exists():
        st.error("Output video file not found")
        st.info("Make sure to complete the video processing before downloading")
        return

    try:
        video_data = Path(output_path).read_bytes()
    except OSError as e:
        st.error(f"Error preparing download: {str(e)}")
        st.info("Please try processing the video again")
        return

    # A file this small holds at most container headers, not real frames.
    if len(video_data) <= 1000:
        st.error("Error: Video file is empty or corrupted")
        st.info("Try processing the video again with different settings")
        return

    # The writer may have fallen back to an AVI container; label the file to match.
    suffix = Path(output_path).suffix or '.mp4'
    mime = "video/x-msvideo" if suffix == '.avi' else "video/mp4"
    st.success(f"Successfully processed {state.processed_frames} frames")
    st.download_button(
        label="📥 Download Processed Video",
        data=video_data,
        file_name=f"detected_video_{time.strftime('%Y%m%d_%H%M%S')}{suffix}",
        mime=mime,
        key="download_button",
    )


def main():
    """Streamlit entry point: pick a model and input, run detection, offer the result.

    Streamlit re-executes this function on every widget interaction. The
    capture and writer are therefore opened only while a detection run is
    active (``st.session_state.run_detection``). They are released at the end
    of every script run, so idle reruns never truncate a finished output file.
    """
    st.title("Real-Time Object Detection")
    _init_session_state()
    state = st.session_state

    # Sidebar settings
    st.sidebar.title("Settings")
    _render_model_selector()
    detection_confidence = st.sidebar.slider("Detection Confidence", 0.0, 1.0, 0.5)

    # Input selection (widgets only; nothing is opened until a run starts)
    input_source = st.radio("Select Input Source", ["Video File", "Live Stream URL"])
    video_file = None
    url = None
    if input_source == "Video File":
        video_file = st.file_uploader("Upload Video", type=['mp4', 'avi'])
    else:
        url = st.text_input("Enter Stream URL")
    has_input = video_file is not None or bool(url)

    video_placeholder = st.empty()

    st.sidebar.markdown("---")
    st.sidebar.subheader("Controls")
    start_button = st.sidebar.button("Start Detection")
    stop_button = st.sidebar.button("Stop Detection")

    if start_button:
        if state.model is None:
            st.error("Please load a model first using the 'Load Selected Model' button")
            st.stop()
        if not has_input:
            st.error("Please upload a video or provide a stream URL first")
            st.stop()
        state.run_detection = True
    if stop_button:
        state.run_detection = False
    if state.run_detection and (state.model is None or not has_input):
        # The model or the input went away while a run was in progress.
        state.run_detection = False

    try:
        if state.run_detection:
            # A new writer truncates any earlier output, so the count restarts too.
            state.processed_frames = 0
            state.cap, state.out, state.output_path = initialize_video_capture(
                input_source, video_file=video_file, url=url
            )
            if state.cap is None:
                # With an input present this means no output writer could be
                # created; initialize_video_capture reports that itself.
                state.run_detection = False
                st.stop()
            if not state.cap.isOpened():
                state.run_detection = False
                st.error("Error: Could not open video source")
                st.stop()
            _process_stream(video_placeholder, detection_confidence)
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        raise
    finally:
        # Runs on normal exit, errors and Streamlit's stop/rerun interrupts,
        # so the output file is always finalized.
        _release_resources()

    st.markdown("---")
    _render_download_section()


if __name__ == "__main__":
    main()