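"""Gradio app that counts people crossing a horizontal line in an uploaded video.

Runtime dependencies implied by the imports below (typically pinned in the
Space's requirements.txt): gradio, opencv-python, ultralytics, numpy.
"""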
import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO
from collections import defaultdict
import tempfile
import os


class PersonCounter:
    def __init__(self, line_position=0.5):
        self.model = YOLO("yolov8n.pt")        # pretrained YOLOv8 nano detector
        self.tracker = defaultdict(list)       # last feet y-position seen per track id
        self.crossed_ids = set()               # track ids that have already been counted
        self.line_position = line_position     # counting line as a fraction of frame height
        self.count = 0                         # number of line crossings so far

    def process_frame(self, frame):
        height, width = frame.shape[:2]
        line_y = int(height * self.line_position)

        # Draw counting line
        cv2.line(frame, (0, line_y), (width, line_y), (0, 255, 0), 2)

        # Run detection and tracking, restricted to class 0 ("person" in COCO)
        results = self.model.track(frame, persist=True, classes=[0])

        if results[0].boxes.id is not None:
            boxes = results[0].boxes.xyxy.cpu().numpy()
            track_ids = results[0].boxes.id.cpu().numpy().astype(int)
            for box, track_id in zip(boxes, track_ids):
                # Draw bounding box
                cv2.rectangle(frame, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                              (255, 0, 0), 2)

                # Track the bottom-center of the box, which approximates the feet position
                center_x = (box[0] + box[2]) / 2
                feet_y = box[3]

                # Draw tracking point
                cv2.circle(frame, (int(center_x), int(feet_y)), 5, (0, 255, 255), -1)

                # Compare against the previous frame's feet position, if we have one
                if track_id in self.tracker:
                    prev_y = self.tracker[track_id][-1]
                    # Count the person once when they cross the line from top to bottom
                    if prev_y < line_y and feet_y >= line_y and track_id not in self.crossed_ids:
                        self.crossed_ids.add(track_id)
                        self.count += 1
                        # Draw crossing indicator
                        cv2.circle(frame, (int(center_x), int(line_y)), 8, (0, 0, 255), -1)

                # Remember this frame's feet position for the next frame's check
                self.tracker[track_id] = [feet_y]

        # Draw count with background
        count_text = f"Count: {self.count}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1.5
        thickness = 3
        (text_width, text_height), _ = cv2.getTextSize(count_text, font, font_scale, thickness)
        cv2.rectangle(frame, (10, 10), (20 + text_width, 20 + text_height),
                      (0, 0, 0), -1)
        cv2.putText(frame, count_text, (15, 15 + text_height),
                    font, font_scale, (0, 255, 0), thickness)

        return frame
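

# Quick standalone sanity check (a sketch, not wired into the Gradio app below);
# "sample.jpg" is a placeholder path, not a file shipped with this Space:
#
#   counter = PersonCounter(line_position=0.5)
#   frame = cv2.imread("sample.jpg")
#   annotated = counter.process_frame(frame)
#   cv2.imwrite("annotated.jpg", annotated)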


def process_video(video_path, progress=gr.Progress()):
    # Create temp directory for the annotated output video
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "result.mp4")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Could not open video file")
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # some files report 0 FPS; fall back to 30
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    counter = PersonCounter(line_position=0.5)

    for frame_idx in progress.tqdm(range(total_frames)):
        ret, frame = cap.read()
        if not ret:
            break
        processed_frame = counter.process_frame(frame)
        writer.write(processed_frame)

    cap.release()
    writer.release()

    return output_path, f"Final count: {counter.count} people entered"


# Create Gradio interface
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Upload a video file"),
    outputs=[
        gr.Video(label="Processed Video"),
        gr.Textbox(label="Results"),
    ],
    title="Store Entry People Counter",
    description=(
        "Upload a video to count the number of people entering across a line. "
        "The green line is the counting threshold, blue boxes show detected people, "
        "and the counter increases when someone crosses the line from top to bottom."
    ),
    examples=[],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()