import cv2
import gradio as gr
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSort
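# Dependencies (assumed package names, unpinned):
#   pip install ultralytics deep-sort-realtime gradio opencv-python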

class ObjectTracker:
    def __init__(self, person_model_path='yolov8n.pt'):
        """
        Initialize object tracker with YOLO and DeepSort
        """
        # Load YOLO model for person detection
        self.model = YOLO(person_model_path)
        
        # Initialize DeepSort tracker
        self.tracker = DeepSort(
            max_age=30,  # Tracks can be lost for up to 30 frames
            n_init=3,    # Number of consecutive detections before track is confirmed
        )
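        # Note: a larger max_age tolerates longer occlusions but raises the
        # chance of identity switches; a higher n_init filters out spurious
        # detections at the cost of a short confirmation delay.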
        
        # Tracking statistics
        self.person_count = 0
        self.tracking_data = {}

    def process_frame(self, frame):
        """
        Process a single frame for object detection and tracking
        """
        # Detect persons (COCO class 0) with YOLO
        results = self.model(frame, classes=[0], conf=0.5)

        # Build detections in the ([left, top, w, h], confidence, class)
        # format that DeepSort's update_tracks expects
        detections = []
        for r in results:
            for box in r.boxes:
                # Convert xyxy corners to [left, top, width, height]
                x1, y1, x2, y2 = box.xyxy[0]
                bbox = [x1.item(), y1.item(), (x2 - x1).item(), (y2 - y1).item()]
                conf = box.conf.item()
                detections.append((bbox, conf, 0))  # 0 = COCO 'person' class
        
        # Update tracks on every frame; passing an empty detection list
        # still lets DeepSort age out lost tracks correctly
        tracks = self.tracker.update_tracks(detections, frame=frame)

        # Annotate frame with tracking information
        for track in tracks:
            if not track.is_confirmed():
                continue

            track_id = track.track_id
            ltrb = track.to_ltrb()  # track box as (left, top, right, bottom)

            # Draw bounding box
            cv2.rectangle(
                frame,
                (int(ltrb[0]), int(ltrb[1])),
                (int(ltrb[2]), int(ltrb[3])),
                (0, 255, 0),
                2
            )

            # Label the box with its track ID
            cv2.putText(
                frame,
                f'ID: {track_id}',
                (int(ltrb[0]), int(ltrb[1] - 10)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.9,
                (0, 255, 0),
                2
            )
        
        return frame
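
# Example (sketch): annotate a single frame read from disk.
# 'sample.jpg' is a placeholder path, not part of the original script.
#
#   tracker = ObjectTracker()
#   frame = cv2.imread('sample.jpg')
#   annotated = tracker.process_frame(frame)
#   cv2.imwrite('sample_tracked.jpg', annotated)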

def process_video(input_video):
    """
    Main video processing function for Gradio
    """
    # Initialize tracker
    tracker = ObjectTracker()
    
    # Open input video
    cap = cv2.VideoCapture(input_video)

    # Prepare output video writer
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back if FPS is unreported

    # mp4v is widely available; 'avc1' (H.264) may play more reliably in
    # browsers if your OpenCV build supports it. The fixed output path is
    # fine for a single-user demo.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter('output_tracked.mp4', fourcc, fps, (width, height))
    
    # Process video frames
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        
        # Process and annotate frame
        processed_frame = tracker.process_frame(frame)
        
        # Write processed frame
        out.write(processed_frame)
    
    # Release resources
    cap.release()
    out.release()
    
    return 'output_tracked.mp4'
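
# Example (sketch): run the pipeline headlessly on a local file.
# 'input.mp4' is a placeholder path, not part of the original script.
#
#   result_path = process_video('input.mp4')
#   print(f'Annotated video written to {result_path}')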

# Create Gradio interface
iface = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Upload Video for Tracking"),
    outputs=gr.Video(label="Tracked Video"),
    title="Person Tracking with YOLO and DeepSort",
    description="Upload a video to track and annotate person movements"
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()