# WildfireWatch — Streamlit wildfire/smoke detection app (Hugging Face Spaces).
# Required libraries (ensure these are in your requirements.txt):
#   streamlit
#   opencv-python-headless
#   ultralytics
#   Pillow
import os
import tempfile

import cv2
import PIL
import streamlit as st
from ultralytics import YOLO

# Trained YOLO weights — replace with your model's URL or local path.
# NOTE: must be a single unbroken string; a newline inside the URL breaks loading.
model_path = 'https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/best.pt'
# Configure the page for Hugging Face Spaces | |
st.set_page_config( | |
page_title="Fire Watch using AI vision models", | |
page_icon="🔥", | |
layout="wide", | |
initial_sidebar_state="expanded" | |
) | |
# Sidebar for file upload and settings | |
with st.sidebar: | |
st.header("IMAGE/VIDEO UPLOAD") | |
source_file = st.file_uploader( | |
"Choose an image or video...", type=("jpg", "jpeg", "png", "bmp", "webp", "mp4")) | |
confidence = float(st.slider("Select Model Confidence", 25, 100, 40)) / 100 | |
video_option = st.selectbox( | |
"Select Video Shortening Option", | |
["Original FPS", "1 fps", "1 frame per 5 seconds", "1 frame per 10 seconds", "1 frame per 15 seconds"] | |
) | |
# Main page header and introduction images | |
st.title("WildfireWatch: Detecting Wildfire using AI") | |
col1, col2 = st.columns(2) | |
with col1: | |
st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_1.jpeg", use_column_width=True) | |
with col2: | |
st.image("https://huggingface.co/spaces/tstone87/ccr-colorado/resolve/main/Fire_3.png", use_column_width=True) | |
st.markdown(""" | |
Fires in Colorado present a serious challenge, threatening urban communities, highways, and even remote areas. Early detection is critical to mitigating risks. WildfireWatch leverages the YOLOv8 model for real-time fire and smoke detection in images and videos. | |
""") | |
st.markdown("---") | |
st.header("Fire Detection:") | |
col1, col2 = st.columns(2) | |
# Preview the upload. Images are shown immediately; videos are spooled to a
# temp file so OpenCV can open them by path (cv2 cannot read file-like objects).
if source_file:
    if source_file.type.split('/')[0] == 'image':
        uploaded_image = PIL.Image.open(source_file)
        st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    else:
        tfile = tempfile.NamedTemporaryFile(delete=False)
        tfile.write(source_file.read())
        tfile.flush()  # ensure all bytes reach disk before cv2 opens the file
        vidcap = cv2.VideoCapture(tfile.name)
else:
    st.info("Please upload an image or video file to begin.")
# Load the YOLO model | |
try: | |
model = YOLO(model_path) | |
except Exception as ex: | |
st.error(f"Unable to load model. Check the specified path: {model_path}") | |
st.error(ex) | |
def _sampling_params(option, fps):
    """Map the UI shortening option to ``(sample_interval, output_fps)``.

    ``sample_interval`` is the stride (in source frames) between processed
    frames; ``output_fps`` is the frame rate of the shortened output video.
    ``fps`` is the source video's reported FPS — some containers report 0 or
    a fraction, so the interval is clamped to at least 1 to keep the
    ``frame_count % sample_interval`` test valid (avoids ZeroDivisionError).
    """
    seconds_per_frame = {
        "1 fps": 1,
        "1 frame per 5 seconds": 5,
        "1 frame per 10 seconds": 10,
        "1 frame per 15 seconds": 15,
    }
    if option in seconds_per_frame:
        secs = seconds_per_frame[option]
        interval = int(fps * secs) if fps > 0 else secs
        return max(1, interval), 1
    # "Original FPS" (or anything unrecognized): keep every frame.
    return 1, fps


if st.sidebar.button("Let's Detect Wildfire"):
    if not source_file:
        st.warning("No file uploaded!")
    elif source_file.type.split('/')[0] == 'image':
        # Single image: run the model once and show the annotated result.
        res = model.predict(uploaded_image, conf=confidence)
        boxes = res[0].boxes
        res_plotted = res[0].plot()[:, :, ::-1]  # BGR -> RGB for st.image
        with col2:
            st.image(res_plotted, caption='Detected Image', use_column_width=True)
            with st.expander("Detection Results"):
                for box in boxes:
                    st.write(box.xywh)
    else:
        # Video: sample frames per the chosen option, run detection on each
        # sampled frame, then assemble the annotated frames into a new video.
        processed_frames = []  # annotated frames kept in BGR for cv2.VideoWriter
        frame_count = 0
        # Source video properties.
        orig_fps = vidcap.get(cv2.CAP_PROP_FPS)
        width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        sample_interval, output_fps = _sampling_params(video_option, orig_fps)
        if output_fps <= 0:
            output_fps = 30  # source did not report FPS; fall back to a sane default
        success, image = vidcap.read()
        while success:
            if frame_count % sample_interval == 0:
                res = model.predict(image, conf=confidence)
                # Keep the annotated frame in BGR: cv2.VideoWriter expects BGR.
                # (Writing the RGB-reversed copy would swap red/blue in the
                # downloaded video.)
                annotated_bgr = res[0].plot()
                processed_frames.append(annotated_bgr)
                with col2:
                    # st.image expects RGB, so reverse channels for display only.
                    st.image(annotated_bgr[:, :, ::-1],
                             caption=f'Detected Frame {frame_count}',
                             use_column_width=True)
                    with st.expander("Detection Results"):
                        for box in res[0].boxes:
                            st.write(box.xywh)
            frame_count += 1
            success, image = vidcap.read()
        vidcap.release()  # free the capture handle once all frames are read
        if processed_frames:
            temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(temp_video_file.name, fourcc, output_fps, (width, height))
            for frame in processed_frames:
                out.write(frame)
            out.release()
            st.success("Shortened video created successfully!")
            with open(temp_video_file.name, 'rb') as video_file:
                st.download_button(
                    label="Download Shortened Video",
                    data=video_file.read(),
                    file_name="shortened_video.mp4",
                    mime="video/mp4"
                )
        else:
            st.error("No frames were processed from the video.")