# app.py — YOLOv8 object tracking Gradio demo (counts unique people and cars)
import cv2
import torch
import gradio as gr
from ultralytics import YOLO
import time
# Check if MPS (Metal Performance Shaders) is available, otherwise use CPU
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device('mps')
    print("MPS available, using MPS.")
else:
    device = torch.device('cpu')
    print("MPS not available, using CPU.")

# Load the YOLOv8 model
# NOTE: downloads yolov8n.pt on first run if not cached locally.
model = YOLO("yolov8n.pt").to(device)

# Classes to count: 0 = person, 2 = car
classes_to_count = [0, 2] # person and car classes for counting

# Initialize unique ID storage for each class
# (module-level accumulators; reset inside process_video for each new upload)
unique_people_ids = set()
unique_car_ids = set()
def process_video(video_input):
    """Run YOLOv8 detection + tracking over a video and annotate frames.

    Accumulates unique tracker IDs for people (class 0) and cars (class 2)
    across the whole video, drawing the running totals on each processed
    frame. Frames are subsampled to keep latency down; skipped frames are
    dropped from the output entirely.

    Args:
        video_input: Path to the input video file.

    Returns:
        List of annotated BGR frames (numpy arrays), one per processed frame.

    Raises:
        ValueError: If the video file cannot be opened.
    """
    global unique_people_ids, unique_car_ids
    # Reset the module-level accumulators so each uploaded video starts fresh.
    unique_people_ids = set()
    unique_car_ids = set()

    cap = cv2.VideoCapture(video_input)
    # Raise instead of assert: asserts are stripped under `python -O`.
    if not cap.isOpened():
        raise ValueError("Error reading video file")

    output_frames = []
    frame_counter = 0
    frame_skip = 5  # process every 5th frame to reduce inference cost

    while True:
        success, frame = cap.read()
        if not success:
            break
        if frame_counter % frame_skip != 0:
            frame_counter += 1
            continue

        # Run object detection and tracking on the frame.
        # persist=True keeps tracker state between successive calls so IDs
        # remain stable across frames.
        results = model.track(frame, persist=True, device=device,
                              classes=classes_to_count, verbose=False, conf=0.4)

        # Record unique tracker IDs per class.
        for det in results[0].boxes:
            # det.id is None until the tracker has assigned an ID to this
            # detection — skip those instead of reusing a stale/undefined id.
            if det.id is None:
                continue
            object_id = int(det.id[0])
            cls = int(det.cls)
            if cls == 0:  # person
                unique_people_ids.add(object_id)
            elif cls == 2:  # car
                unique_car_ids.add(object_id)

        # Annotate the frame with boxes plus the running unique totals.
        annotated_frame = results[0].plot()
        cv2.putText(annotated_frame,
                    f'Unique People: {len(unique_people_ids)} | Unique Cars: {len(unique_car_ids)}',
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        output_frames.append(annotated_frame)
        frame_counter += 1

    cap.release()
    return output_frames
def video_pipeline(video_file):
    """Process an uploaded video and encode the annotated frames to MP4.

    Args:
        video_file: Path to the uploaded video (as supplied by Gradio).

    Returns:
        Path to the written output video ('output.mp4').

    Raises:
        ValueError: If no frames could be processed from the input
            (e.g. an empty or unreadable video), which would otherwise
            crash with an IndexError on output_frames[0].
    """
    output_frames = process_video(video_file)
    if not output_frames:
        raise ValueError("No frames could be processed from the input video.")

    output_video_path = 'output.mp4'
    # All frames share the shape of the first annotated frame.
    h, w, _ = output_frames[0].shape
    out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 20, (w, h))
    try:
        for frame in output_frames:
            out.write(frame)
    finally:
        # Always release the writer so the container is finalized even if a
        # write fails partway through.
        out.release()
    return output_video_path
# Gradio Interface
title = "YOLOv8 Object Tracking with Unique ID Counting"
description = "Upload a video to detect and count unique people and cars using YOLOv8."

interface = gr.Interface(
    fn=video_pipeline,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Video(label="Processed Video"),
    title=title,
    description=description,
    live=True  # NOTE(review): live=True triggers reruns on input change — confirm this is intended for a heavy video pipeline
)

# Launch Gradio interface
interface.launch()