Spaces:
Configuration error
Configuration error
import tempfile
import time

import cv2
import gradio as gr
import torch
from ultralytics import YOLO
def _select_device():
    """Return the torch device used for inference and report the choice.

    Apple's Metal Performance Shaders (MPS) backend is used only when it is
    both available on this machine and built into this PyTorch install;
    otherwise the CPU is used.
    """
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        chosen = torch.device('mps')
        print("MPS available, using MPS.")
        return chosen
    chosen = torch.device('cpu')
    print("MPS not available, using CPU.")
    return chosen


# Device shared by model loading and tracking below.
device = _select_device()
# Nano YOLOv8 checkpoint, moved onto the selected inference device.
model = YOLO("yolov8n.pt")
model = model.to(device)

# COCO class indices passed to the tracker and counted: 0 -> person, 2 -> car.
classes_to_count = [0, 2]

# Tracker IDs already seen, one set per counted class.
unique_people_ids, unique_car_ids = set(), set()
def _update_unique_ids(boxes):
    """Add the tracker ID of each person/car box to the module-level ID sets.

    Boxes without a tracker ID (``det.id is None``) are skipped. Class 0 IDs go
    into ``unique_people_ids`` and class 2 IDs into ``unique_car_ids``; since
    these are sets, an ID seen on many frames is only counted once.
    """
    for det in boxes:
        # Check for a missing ID explicitly instead of swallowing the error:
        # the old bare except left object_id undefined (NameError) or holding
        # the previous box's ID, which miscounted untracked detections.
        if det.id is None:
            continue
        object_id = int(det.id[0])
        class_id = int(det.cls[0])
        if class_id == 0:  # person
            unique_people_ids.add(object_id)
        elif class_id == 2:  # car
            unique_car_ids.add(object_id)


def process_video(video_input):
    """Track people and cars in a video and return the annotated frames.

    Every 5th frame is read and passed to ``model.track`` (restricted to
    ``classes_to_count``). The frame is then annotated with the running totals
    of unique tracker IDs seen for people (class 0) and cars (class 2).

    Side effects:
        Rebinds the module-level ``unique_people_ids`` and ``unique_car_ids``
        to fresh sets, then fills them with the IDs observed in this video.
        NOTE(review): these are module globals, so concurrent requests would
        share them.

    Args:
        video_input: Any source ``cv2.VideoCapture`` accepts; from the Gradio
            UI this is presumably a file path (TODO confirm).

    Returns:
        list: One annotated frame (from ``results[0].plot()`` with the count
        overlay drawn on it) per processed frame.

    Raises:
        ValueError: If the video cannot be opened.
    """
    global unique_people_ids, unique_car_ids
    unique_people_ids = set()
    unique_car_ids = set()

    frame_skip = 5  # run inference on every 5th frame only, to save compute

    cap = cv2.VideoCapture(video_input)
    if not cap.isOpened():
        cap.release()
        # Explicit raise instead of assert, which is stripped under `python -O`.
        raise ValueError("Error reading video file")

    output_frames = []
    try:
        frame_counter = 0
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
            should_process = frame_counter % frame_skip == 0
            frame_counter += 1
            if not should_process:
                continue

            # persist=True keeps track state between successive calls so IDs
            # stay stable across frames. NOTE(review): that state may also
            # carry over between uploaded videos; confirm whether a reset is
            # wanted.
            results = model.track(frame, persist=True, device=device, classes=classes_to_count, verbose=False, conf=0.4)
            result = results[0]
            _update_unique_ids(result.boxes)

            annotated_frame = result.plot()
            cv2.putText(annotated_frame, f'Unique People: {len(unique_people_ids)} | Unique Cars: {len(unique_car_ids)}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            output_frames.append(annotated_frame)
    finally:
        # Release the capture even if tracking or annotation raises.
        cap.release()

    return output_frames
def video_pipeline(video_file):
    """Annotate ``video_file`` and encode the resulting frames as an MP4.

    Args:
        video_file: Video source passed unchanged to ``process_video``.

    Returns:
        str: Path to a newly created ``.mp4`` temporary file. It contains the
        annotated frames, written at 20 fps with the ``mp4v`` codec.

    Raises:
        ValueError: If ``process_video`` produced no frames (empty or
            unreadable input).
    """
    output_frames = process_video(video_file)
    if not output_frames:
        # Fail clearly instead of with an IndexError on output_frames[0].
        raise ValueError("No frames could be read from the input video.")

    # Use a unique temp file per call rather than a fixed filename, so
    # concurrent requests cannot overwrite each other's output.
    # NOTE(review): delete=False leaves files in the temp dir; confirm cleanup.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_video_path = tmp.name

    height, width = output_frames[0].shape[:2]
    # NOTE(review): 20 fps is hard-coded, while process_video keeps only every
    # 5th source frame, so playback speed generally differs from the source.
    writer = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 20, (width, height))
    try:
        for frame in output_frames:
            writer.write(frame)
    finally:
        # Finalize the container even if a write fails.
        writer.release()
    return output_video_path
# --- Gradio UI: a single video in, the annotated video out -----------------
title = "YOLOv8 Object Tracking with Unique ID Counting"
description = "Upload a video to detect and count unique people and cars using YOLOv8."

_video_in = gr.Video(label="Input Video")
_video_out = gr.Video(label="Processed Video")

# live=True: Gradio re-runs `fn` automatically whenever the input changes.
interface = gr.Interface(
    fn=video_pipeline,
    inputs=_video_in,
    outputs=_video_out,
    title=title,
    description=description,
    live=True,
)

# Start the web server.
interface.launch()