Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import tempfile | |
| import cv2 | |
| import streamlit as st | |
| from ultralytics import YOLO | |
| from huggingface_hub import hf_hub_url, cached_download | |
@st.cache_resource
def load_model():
    """Download the soccer-ball YOLO weights from the Hugging Face Hub and load them.

    ``hf_hub_download`` caches the weights locally and returns a path that
    already ends in the repository filename (``soccer_ball.pt``). That means
    Ultralytics can load it directly, with no rename. The previous approach,
    ``cached_download`` plus ``os.rename``, moved the file out of the cache,
    so it was downloaded again on every call. ``cached_download`` is also
    deprecated in newer ``huggingface_hub`` releases.

    ``st.cache_resource`` keeps a single loaded model per server process.
    Streamlit re-executes the whole script on every interaction, so without
    it the model would be reloaded each time.

    Returns:
        YOLO: the Ultralytics model built from the downloaded weights.
    """
    # Local import so the file-level import line can stay unchanged.
    from huggingface_hub import hf_hub_download

    repo_id = "BreakIntoData/cv_workshop"
    filename = "soccer_ball.pt"
    # Downloads on first use, then returns the locally cached file's path.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f"Downloaded model to {model_path}")
    return YOLO(model_path)
def process_video(video_path, output_path, detector=None):
    """Run the detector on every frame of a video and write an annotated copy.

    While the video is processed, Streamlit placeholders show a progress bar,
    the current frame, and elapsed / estimated remaining time.

    Args:
        video_path: Path of the input video, as accepted by ``cv2.VideoCapture``.
        output_path: Path the annotated video is written to.
        detector: Optional callable model (e.g. an Ultralytics ``YOLO``) whose
            result list's first element has a ``plot()`` method. Defaults to
            the module-level ``model``.

    Raises:
        ValueError: if the input video cannot be opened, or no video writer
            can be created for ``output_path``.
    """
    if detector is None:
        detector = model  # module-level model created by the script body

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates such as 29.97. int() would make the output play
    # at the wrong speed. Fall back to 30 when the container reports no rate.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    # CAP_PROP_FRAME_COUNT is only an estimate and may be 0 for some inputs,
    # so it drives the progress display but not the loop.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # H.264 ('avc1') plays in browsers, but many OpenCV builds cannot encode
    # it. In that case VideoWriter silently fails to open, so fall back.
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'avc1'), fps, (width, height))
    if not out.isOpened():
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        if not out.isOpened():
            cap.release()
            raise ValueError(f"Could not create video writer for: {output_path}")
        st.warning("H.264 encoder unavailable; output encoded as MPEG-4 and may not play in the browser.")

    progress_bar = st.progress(0)
    status_text = st.empty()
    time_text = st.empty()
    start_time = time.time()
    frames_done = 0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            results = detector(frame)
            out.write(results[0].plot())
            frames_done += 1

            # Update progress
            elapsed_time = time.time() - start_time
            if total_frames > 0:
                # Clamp: the frame-count estimate can be lower than the real
                # count, and st.progress rejects values above 1.0.
                progress = min(frames_done / total_frames, 1.0)
                remaining_time = elapsed_time / progress - elapsed_time
                progress_bar.progress(progress)
                status_text.text(f"Processing frame {frames_done}/{total_frames}")
                time_text.text(
                    f"Elapsed time: {elapsed_time:.2f}s | Estimated time remaining: {remaining_time:.2f}s"
                )
            else:
                status_text.text(f"Processing frame {frames_done}")
                time_text.text(f"Elapsed time: {elapsed_time:.2f}s")
    finally:
        # Release even if the detector raises, so the output file is finalized.
        cap.release()
        out.release()

    progress_bar.empty()
    status_text.text(f"Processed {frames_done} frames")
    time_text.text(f"Total time: {time.time() - start_time:.2f}s")
model = load_model()

st.title("Soccer Ball Detection App")

# Sidebar for options
st.sidebar.header("Options")
video_option = st.sidebar.radio("Choose video source:", ("Use preset video", "Upload video"))

# Path of the video to display/process; stays None until a source is available.
video_path = None
if video_option == "Upload video":
    uploaded_file = st.sidebar.file_uploader("Choose a video file", type=["mp4", "avi", "mov"])
    if uploaded_file is not None:
        # Keep the upload's own extension so the container is probed correctly.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".mp4"
        # Leaving the `with` closes (and flushes) the file before st.video and
        # cv2 read it. getvalue() returns the full upload regardless of the
        # current read position.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            tfile.write(uploaded_file.getvalue())
        video_path = tfile.name
else:
    preset_videos = {
        "Ronaldo": "preset_videos/Ronaldo.mp4",
        "Sancho": "preset_videos/CityUtdR video.mp4",
        "Messi": "preset_videos/Messi.mp4",
    }
    selected_video = st.sidebar.selectbox("Select a preset video", list(preset_videos.keys()))
    video_path = preset_videos[selected_video]

if video_path is not None:
    try:
        st.header("Original Video")
        st.video(video_path)

        if st.button("Detect"):
            # Reserve an output path, then close the handle so cv2 can write to
            # the path (an open NamedTemporaryFile cannot be reopened on Windows).
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                output_path = temp_file.name
            try:
                with st.spinner("Processing video..."):
                    process_video(video_path, output_path)
                st.success("Video processing complete!")

                st.header("Processed Video")
                st.video(output_path)

                # Add download link
                with open(output_path, "rb") as file:
                    st.download_button(
                        label="Download Video",
                        data=file,
                        file_name="processed_video.mp4",
                        mime="video/mp4",
                    )
            finally:
                # NOTE(review): assumes st.video / st.download_button copy the
                # file's bytes when called (Streamlit's default in-memory media
                # storage). Confirm if a different media storage is configured.
                os.unlink(output_path)
    finally:
        # The upload is copied to a new temp file on every script rerun, so
        # remove it every run, not only when Detect is pressed.
        if video_option == "Upload video":
            os.unlink(video_path)

st.sidebar.markdown("---")
st.sidebar.write("Developed with ❤️ by Break Into Data")