import cv2
import gradio as gr
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSort
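# Dependencies (PyPI package names; versions not pinned here):
#   pip install ultralytics deep-sort-realtime gradio opencv-python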


class ObjectTracker:
    def __init__(self, person_model_path='yolov8n.pt'):
        """
        Initialize the object tracker with a YOLO detector and a DeepSort tracker.
        """
        # Load YOLO model for person detection
        self.model = YOLO(person_model_path)

        # Initialize DeepSort tracker
        self.tracker = DeepSort(
            max_age=30,  # a track may go unmatched for up to 30 frames
            n_init=3,    # consecutive detections needed to confirm a track
        )

        # Tracking statistics
        self.person_count = 0
        self.tracking_data = {}
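
        # DeepSort track lifecycle: a new track starts as tentative, is
        # confirmed after n_init consecutive matched detections, and is
        # deleted once it goes unmatched for max_age frames.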

    def process_frame(self, frame):
        """
        Run detection and tracking on a single frame and return the
        annotated frame.
        """
        # Detect persons using YOLO (COCO class 0 is 'person')
        results = self.model(frame, classes=[0], conf=0.5)

        # Collect detections as ([left, top, w, h], confidence, class)
        # tuples, the format deep_sort_realtime's update_tracks expects
        detections = []
        for r in results:
            for box in r.boxes:
                x1, y1, x2, y2 = box.xyxy[0]
                bbox = [x1.item(), y1.item(), (x2 - x1).item(), (y2 - y1).item()]
                conf = box.conf.item()
                detections.append((bbox, conf, 'person'))

        # Update tracks; an empty detection list is valid and simply lets
        # existing tracks age out, so `tracks` is always defined
        tracks = self.tracker.update_tracks(detections, frame=frame)
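        # Each returned Track exposes is_confirmed(), a persistent track_id,
        # and to_ltrb() for (left, top, right, bottom) pixel coordinates,
        # which is what the annotation pass below relies on.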

        # Annotate frame with tracking information
        for track in tracks:
            if not track.is_confirmed():
                continue
            track_id = track.track_id
            ltrb = track.to_ltrb()

            # Draw bounding box
            cv2.rectangle(
                frame,
                (int(ltrb[0]), int(ltrb[1])),
                (int(ltrb[2]), int(ltrb[3])),
                (0, 255, 0),
                2
            )

            # Add track ID
            cv2.putText(
                frame,
                f'ID: {track_id}',
                (int(ltrb[0]), int(ltrb[1] - 10)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.9,
                (0, 255, 0),
                2
            )

        return frame
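
# Minimal standalone usage sketch (not part of the Gradio app; 'people.jpg'
# is a hypothetical test image):
#   tracker = ObjectTracker()
#   frame = cv2.imread('people.jpg')
#   annotated = tracker.process_frame(frame)
#   cv2.imwrite('people_tracked.jpg', annotated)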

def process_video(input_video):
    """
    Main video processing function for Gradio.
    """
    # Initialize tracker
    tracker = ObjectTracker()

    # Open input video
    cap = cv2.VideoCapture(input_video)

    # Prepare output video writer; fall back to 30 FPS when the container
    # does not report a frame rate
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter('output_tracked.mp4', fourcc, fps, (width, height))

    # Process video frames
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Process and annotate frame
        processed_frame = tracker.process_frame(frame)

        # Write processed frame
        out.write(processed_frame)

    # Release resources
    cap.release()
    out.release()

    return 'output_tracked.mp4'
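
# Note: the 'mp4v' fourcc works wherever OpenCV does, but many browsers do
# not decode MPEG-4 Part 2 inline; if the Gradio player shows a blank video,
# an H.264 fourcc such as 'avc1' (where the local OpenCV build supports it)
# is a common substitute.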

# Create Gradio interface
iface = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Upload Video for Tracking"),
    outputs=gr.Video(label="Tracked Video"),
    title="Person Tracking with YOLO and DeepSort",
    description="Upload a video to track and annotate person movements."
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()
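    # For a temporary public URL when testing locally, Gradio's launch()
    # also accepts share=True:
    #   iface.launch(share=True)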