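# Spaces: Sleeping
# (The line above appears to be a hosting-page status banner captured with this file; it is not part of the program.)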
import streamlit as st | |
import cv2 | |
import torch | |
import numpy as np | |
import time | |
import tempfile | |
from pathlib import Path | |
# Import detection utilities | |
from detection_utils import load_model, detect_objects, draw_boxes, ObjectTracker | |
def initialize_video_capture(input_source, video_file=None, url=None):
    """Open the selected video source and, for uploaded files, an output writer.

    Args:
        input_source: "Video File" or "Live Stream URL". Any other value
            opens nothing.
        video_file: Binary file-like object with a ``read()`` method holding
            the uploaded video. Only used when ``input_source`` is
            "Video File".
        url: Stream address handed to ``cv2.VideoCapture``. Only used when
            ``input_source`` is "Live Stream URL".

    Returns:
        A ``(cap, out, output_path)`` tuple:

        * Uploaded file opened and writer created: the capture, the open
          ``cv2.VideoWriter`` and the path it writes to.
        * Uploaded file that OpenCV cannot open: ``(cap, None, None)`` with
          ``cap.isOpened()`` false, so the caller can report it.
        * No writer could be created: ``(None, None, None)`` after showing
          an error through ``st.error``.
        * Stream URL: ``(cap, None, None)``; streams are not recorded.
        * Nothing to open: ``(None, None, None)``.

    Note:
        The temporary copy of the uploaded video is created with
        ``delete=False`` and is not removed here, because the returned
        capture reads from it.
    """
    cap = None
    out = None
    output_path = None

    if input_source == "Video File" and video_file is not None:
        # VideoCapture opens files by path, so persist the upload first.
        # The with-block closes the handle (flushing it) before OpenCV reads.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
            tfile.write(video_file.read())
            video_path = tfile.name

        cap = cv2.VideoCapture(video_path)
        if cap.isOpened():
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Keep the fractional rate (e.g. 29.97); truncating it to an int
            # would change the playback speed of the written video.
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:  # also catches NaN
                fps = 30

            temp_dir = Path(tempfile.gettempdir())

            # Codecs in order of preference; availability depends on the
            # local OpenCV build, so each one is simply tried in turn.
            codecs = (
                ('avc1', '.mp4'),
                ('mp4v', '.mp4'),
                ('XVID', '.avi'),
            )
            for codec, ext in codecs:
                candidate_path = str(temp_dir / f'detected_output{ext}')
                try:
                    writer = cv2.VideoWriter(
                        candidate_path,
                        cv2.VideoWriter_fourcc(*codec),
                        fps,
                        (width, height),
                        isColor=True,
                    )
                except Exception:
                    # Best effort: a codec that cannot even be constructed is
                    # skipped in favour of the next candidate.
                    continue

                if writer.isOpened():
                    out = writer
                    output_path = candidate_path
                    break

                # Do not leave a half-initialised writer behind.
                writer.release()

            if out is None:
                # The capture is useless without a writer; release it rather
                # than leaking the handle.
                cap.release()
                st.error("Failed to create video writer")
                return None, None, None
    elif input_source == "Live Stream URL" and url:
        cap = cv2.VideoCapture(url)

    return cap, out, output_path
def get_model_info():
    """Describe every selectable YOLOv8 checkpoint.

    Returns:
        dict: Maps each checkpoint file name to a dict with the display
        fields 'name', 'description', 'speed', 'accuracy', 'size' and
        'details'. Entries are ordered from the smallest model to the
        largest, and a fresh dict is built on every call.
    """
    field_names = ('name', 'description', 'speed', 'accuracy', 'size', 'details')

    # One row per checkpoint: file name first, then values in field order.
    catalog = (
        (
            'yolov8n.pt',
            'YOLOv8 Nano',
            'Smallest and fastest model. Best for CPU or low-power devices.',
            '⚡⚡⚡⚡⚡',
            '⭐⭐',
            '6.7 MB',
            'Ideal for real-time applications with limited computing power.',
        ),
        (
            'yolov8s.pt',
            'YOLOv8 Small',
            'Small model balancing speed and accuracy.',
            '⚡⚡⚡⚡',
            '⭐⭐⭐',
            '22.4 MB',
            'Good for general purpose detection with decent performance.',
        ),
        (
            'yolov8m.pt',
            'YOLOv8 Medium',
            'Medium-sized model with good balance.',
            '⚡⚡⚡',
            '⭐⭐⭐⭐',
            '52.2 MB',
            'Recommended for standard detection tasks with good GPU.',
        ),
        (
            'yolov8l.pt',
            'YOLOv8 Large',
            'Large model with high accuracy.',
            '⚡⚡',
            '⭐⭐⭐⭐⭐',
            '87.7 MB',
            'Best for high-accuracy requirements with good computing power.',
        ),
        (
            'yolov8x.pt',
            'YOLOv8 XLarge',
            'Extra large model with highest accuracy.',
            '⚡',
            '⭐⭐⭐⭐⭐⭐',
            '131.7 MB',
            'Best for tasks requiring maximum accuracy, requires powerful GPU.',
        ),
    )

    return {
        weights_file: dict(zip(field_names, values))
        for weights_file, *values in catalog
    }
def _flush_frame_buffer():
    """Write the buffered annotated frames to the session's writer, if any.

    Every frame written is added to ``st.session_state.processed_frames``.
    The buffer is emptied even when there is no writer (stream input), so it
    cannot grow without bound. Expects ``frame_buffer``, ``out`` and
    ``processed_frames`` to exist in the session state.
    """
    state = st.session_state
    pending = state.frame_buffer
    if not pending:
        return
    writer = state.out
    if writer is not None:
        for buffered_frame in pending:
            writer.write(buffered_frame)
        state.processed_frames += len(pending)
    pending.clear()


def _release_video_resources():
    """Release the session's capture and writer and forget the handles.

    The slots are reset to ``None`` so that a later script run never sees an
    already-released object. Expects ``cap`` and ``out`` to exist in the
    session state.
    """
    state = st.session_state
    for key in ('cap', 'out'):
        handle = state[key]
        if handle is not None:
            handle.release()
            state[key] = None
    cv2.destroyAllWindows()


def _render_download_section():
    """Offer the processed video for download once frames have been written."""
    state = st.session_state
    if state.processed_frames <= 0:
        return

    st.subheader("Download Processed Video")
    # NOTE(review): the writer has already been released by the time this
    # runs, so this wait is probably unnecessary; kept from the original
    # until confirmed on the target platform.
    time.sleep(3)

    output_path = state.output_path
    if not (output_path and Path(output_path).exists()):
        st.error("Output video file not found")
        st.info("Make sure to complete the video processing before downloading")
        return

    try:
        with open(output_path, 'rb') as f:
            video_data = f.read()
        if len(video_data) > 1000:
            st.success(f"Successfully processed {state.processed_frames} frames")
            # The writer may have fallen back to an .avi container, so label
            # the download after the file that was actually written.
            suffix = Path(output_path).suffix or '.mp4'
            mime = "video/x-msvideo" if suffix.lower() == '.avi' else "video/mp4"
            st.download_button(
                label="📥 Download Processed Video",
                data=video_data,
                file_name=f"detected_video_{time.strftime('%Y%m%d_%H%M%S')}{suffix}",
                mime=mime,
                key="download_button"
            )
        else:
            st.error("Error: Video file is empty or corrupted")
            st.info("Try processing the video again with different settings")
    except Exception as e:
        st.error(f"Error preparing download: {str(e)}")
        st.info("Please try processing the video again")


def main():
    """Render the detection UI and, when requested, run the detection loop.

    The sidebar selects and loads a YOLO model and sets the confidence
    threshold; the main area selects an uploaded video or a stream URL.
    "Start Detection" opens the source and processes frames until the source
    is exhausted or the run is interrupted. Annotated frames are shown in a
    placeholder and written in batches of 30 to the output video when a
    writer exists. Afterwards the written file is offered for download.

    The capture and writer are opened only for a script run that actually
    performs detection, and are always released before the function returns.
    Idle reruns therefore never re-create (and truncate) the output file.

    Raises:
        Exception: Any error raised while processing is shown with
            ``st.error`` and then re-raised.
    """
    st.title("Real-Time Object Detection")

    # Initialize session state
    state = st.session_state
    if 'tracker' not in state:
        state.tracker = ObjectTracker()
    if 'frame_buffer' not in state:
        state.frame_buffer = []
    session_defaults = (
        ('cap', None),
        ('out', None),
        ('output_path', None),
        ('processed_frames', 0),
        ('selected_model', 'yolov8x.pt'),
        ('model', None),
        ('run_detection', False),
    )
    for key, default in session_defaults:
        if key not in state:
            state[key] = default

    # Sidebar settings
    st.sidebar.title("Settings")

    # Model selection
    st.sidebar.subheader("Model Selection")
    model_info = get_model_info()
    model_names = list(model_info)
    selected_model = st.sidebar.selectbox(
        "Choose YOLO Model",
        options=model_names,
        format_func=lambda x: model_info[x]['name'],
        index=model_names.index(state.selected_model)
    )

    # Display model information
    details = model_info[selected_model]
    with st.sidebar.expander("Model Details", expanded=True):
        st.markdown(f"**{details['name']}**")
        st.write(details['description'])
        st.write(f"Speed: {details['speed']}")
        st.write(f"Accuracy: {details['accuracy']}")
        st.write(f"Size: {details['size']}")
        st.write(f"Details: {details['details']}")

    # Load the model only on request
    if st.sidebar.button("Load Selected Model"):
        with st.spinner(f"Loading {details['name']}..."):
            state.model = load_model(selected_model)
            state.selected_model = selected_model
        st.sidebar.success("Model loaded successfully!")

    # Detection confidence
    detection_confidence = st.sidebar.slider("Detection Confidence", 0.0, 1.0, 0.5)

    # Input selection
    input_source = st.radio("Select Input Source", ["Video File", "Live Stream URL"])

    try:
        video_file = None
        url = None
        if input_source == "Video File":
            video_file = st.file_uploader("Upload Video", type=['mp4', 'avi'])
        else:
            url = st.text_input("Enter Stream URL")
        source_selected = video_file is not None or bool(url)

        # Create placeholder for video display
        video_placeholder = st.empty()

        # Control buttons live in the sidebar to avoid duplication
        st.sidebar.markdown("---")
        st.sidebar.subheader("Controls")
        start_button = st.sidebar.button("Start Detection")
        stop_button = st.sidebar.button("Stop Detection")

        if start_button:
            if state.model is None:
                st.error("Please load a model first using the 'Load Selected Model' button")
                st.stop()
            if not source_selected:
                st.error("Please upload a video or provide a stream URL first")
                st.stop()
            state.run_detection = True

        if stop_button:
            state.run_detection = False

        if state.run_detection and not source_selected:
            # The source was removed while a run was still flagged active.
            state.run_detection = False

        if state.run_detection:
            # Opening the source creates a fresh output file, so the frame
            # counter and buffer restart together with it.
            state.processed_frames = 0
            state.frame_buffer = []
            state.cap, state.out, state.output_path = initialize_video_capture(
                input_source, video_file=video_file, url=url
            )
            if state.cap is None or not state.cap.isOpened():
                # A None capture means the writer could not be created, which
                # initialize_video_capture has already reported.
                if state.cap is not None:
                    st.error("Error: Could not open video source")
                state.run_detection = False
                st.stop()

        # Detection loop
        frames_seen = 0  # frames read in this run; drives the display rate
        while state.run_detection and state.cap is not None:
            ret, frame = state.cap.read()
            if not ret:
                # Source exhausted or read failed: end the run so that a
                # later rerun does not process the source all over again.
                state.run_detection = False
                break

            # Perform detection and draw the results
            detections = detect_objects(state.model, frame, detection_confidence)
            annotated_frame = draw_boxes(frame, detections, state.tracker)

            # Buffer frames and write them out every 30 frames
            state.frame_buffer.append(annotated_frame)
            if len(state.frame_buffer) >= 30:
                _flush_frame_buffer()

            # Update display every 3rd frame
            if frames_seen % 3 == 0:
                video_placeholder.image(annotated_frame, channels="BGR")
            frames_seen += 1

            # Minimal sleep to prevent UI freezing
            time.sleep(0.001)
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        raise
    finally:
        # Save whatever is still buffered, then always release the handles,
        # even if the final write fails.
        try:
            _flush_frame_buffer()
        finally:
            _release_video_resources()

    # Add a separator
    st.markdown("---")

    # Download section
    _render_download_section()
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()