Abs6187 committed
Commit 58e9978
1 Parent(s): e7bd948

Update app.py

Files changed (1):
  app.py  +105 -177
app.py CHANGED
@@ -1,200 +1,128 @@
- import time
- import torch
  import cv2
  import numpy as np
  import gradio as gr
  from ultralytics import YOLO
- from deep_sort.utils.parser import get_config
- from deep_sort.deep_sort import DeepSort
-
- # Initialize YOLO and DeepSort
- deep_sort_weights = 'ckpt.t7'
- tracker = DeepSort(model_path=deep_sort_weights, max_age=80)
- model = YOLO("person_gun.pt")

- class ObjectDetector:
-     def __init__(self):
-         # Tracking variables
-         self.unique_track_ids = set()
-         self.track_labels = {}
-         self.track_times = {}
-         self.track_positions = {}
-         self.running_counters = {}
-         self.alert_person_ids = []
-
-         # Detection parameters
-         self.running_threshold = 0.5
-         self.fps = 30  # Default FPS

      def process_frame(self, frame):
          """
          Process a single frame for object detection and tracking
          """
-         # Reset alert tracking for this frame
-         self.alert_person_ids.clear()
-         og_frame = frame.copy()
-
-         # Detect persons
-         results = model(frame, device=0, classes=0, conf=0.75)

-         for result in results:
-             boxes = result.boxes
-             cls = boxes.cls.tolist()
-             conf = boxes.conf
-             xywh = boxes.xywh
-
-             pred_cls = np.array(cls)
-             conf = conf.detach().cpu().numpy()
-             bboxes_xywh = xywh.cpu().numpy()
-
-             # Update tracking
-             tracks = tracker.update(bboxes_xywh, conf, og_frame)
-             active_track_ids = set()
-
-             # Reset running status
-             new_running_status = "No Running"
-
-             for track in tracker.tracker.tracks:
-                 track_id = track.track_id
-                 x1, y1, x2, y2 = track.to_tlbr()
-                 w = x2 - x1
-                 h = y2 - y1
-
-                 # Define color for bounding box
-                 red_color = (0, 0, 255)
-                 blue_color = (255, 0, 0)
-                 green_color = (0, 255, 0)
-                 color_id = track_id % 3
-                 color = red_color if color_id == 0 else blue_color if color_id == 1 else green_color
-                 cv2.rectangle(og_frame, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
-
-                 # Initialize tracking for new tracks
-                 if track_id not in self.track_labels:
-                     self.track_labels[track_id] = "Person"
-                     self.track_times[track_id] = 0
-                     self.track_positions[track_id] = (x1, y1)
-                     self.running_counters[track_id] = 0
-
-                 self.track_times[track_id] += 1
-                 prev_x1, prev_y1 = self.track_positions[track_id]
-                 displacement = np.sqrt((x1 - prev_x1) ** 2 + (y1 - prev_y1) ** 2)
-
-                 # Calculate speed
-                 speed = displacement / self.fps if self.fps > 0 else 0
-
-                 self.track_positions[track_id] = (x1, y1)
-
-                 # Detect running
-                 if speed > self.running_threshold and w * h > 5000:
-                     self.running_counters[track_id] += 1
-                     if self.running_counters[track_id] > self.fps / 2:
-                         self.track_labels[track_id] = "Running"
-                         new_running_status = "Running Detected"
-                 else:
-                     self.running_counters[track_id] = 0
-                     self.track_labels[track_id] = "Person"
-
-                 # Track time and potential alerts
-                 total_seconds = self.track_times[track_id] / self.fps if self.fps > 0 else 0
-                 minutes = int(total_seconds // 60)
-                 seconds = int(total_seconds % 60)
-
-                 # Trigger alert for prolonged stay
-                 if total_seconds > 60 and track_id not in self.alert_person_ids:
-                     self.alert_person_ids.append(track_id)
-
-                 # Add label to frame
-                 cv2.putText(og_frame, f"{self.track_labels[track_id]} {minutes:02}:{seconds:02}",
-                             (int(x1) + 10, int(y1) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

-                 active_track_ids.add(track_id)
-
-             # Update unique track IDs
-             self.unique_track_ids.intersection_update(active_track_ids)
-             self.unique_track_ids.update(active_track_ids)
-
-         # Prepare result dictionary
-         result_info = {
-             'person_count': len(self.unique_track_ids),
-             'running_status': new_running_status,
-             'prolonged_stay_ids': list(self.alert_person_ids)
-         }
-
-         return og_frame, result_info
-
-     def process_input(self, input_media):
-         """
-         Process either video or webcam input
-         """
-         # Determine input type
-         if isinstance(input_media, str):  # Video file path
-             cap = cv2.VideoCapture(input_media)
-         else:  # Webcam input
-             cap = cv2.VideoCapture(0)
-
-         # Prepare output video
-         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-         fps = cap.get(cv2.CAP_PROP_FPS) or 30
-         self.fps = fps
-         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-         out = cv2.VideoWriter('output_detection.mp4', fourcc, fps, (width, height))
-
-         # Processing loop
-         frame_info_list = []
-         while cap.isOpened():
-             ret, frame = cap.read()
-             if not ret:
-                 break
-
-             # Process frame
-             processed_frame, frame_info = self.process_frame(frame)
-             out.write(processed_frame)
-             frame_info_list.append(frame_info)
-
-         # Release resources
-         cap.release()
-         out.release()
-
-         return 'output_detection.mp4', frame_info_list
-
- # Create Gradio interface
- detector = ObjectDetector()

- def detect_interface(input_media):
      """
-     Gradio interface function for detection
      """
-     output_video, frame_info_list = detector.process_input(input_media)

-     # Generate text summary
-     summary = "Detection Summary:\n"
-     if frame_info_list:
-         # Take the last frame's information
-         last_frame_info = frame_info_list[-1]
-         summary += f"Total Persons Detected: {last_frame_info['person_count']}\n"
-         summary += f"Running Status: {last_frame_info['running_status']}\n"
-         if last_frame_info['prolonged_stay_ids']:
-             summary += f"Prolonged Stay Detected - Person IDs: {last_frame_info['prolonged_stay_ids']}"
-         else:
-             summary += "No Prolonged Stay Detected"

-     return output_video, summary

- # Gradio Interface
  iface = gr.Interface(
-     fn=detect_interface,
-     inputs=[
-         gr.File(label="Upload Video", type="filepath"),
-         gr.Webcam(label="Or Use Webcam")
-     ],
-     outputs=[
-         gr.Video(label="Processed Video"),
-         gr.Textbox(label="Detection Summary")
-     ],
-     title="Object Detection with Tracking",
-     description="Upload a video or use webcam for real-time object detection and tracking"
  )

  # Launch the interface
- iface.launch()
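A note on the running check in the code removed above: speed is the per-frame displacement of a track's top-left corner divided by the clip FPS, so with the defaults (fps = 30, running_threshold = 0.5) the counter increments only when that corner moves more than 15 pixels between consecutive frames. A worked check of the arithmetic, with an assumed displacement:

    displacement = 16.0               # pixels moved since the previous frame (assumed)
    fps = 30
    running_threshold = 0.5
    speed = displacement / fps        # about 0.533
    assert speed > running_threshold  # the running counter increments this frame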
 
 
 
 
  import cv2
+ import torch
  import numpy as np
  import gradio as gr
  from ultralytics import YOLO
+ from deep_sort_realtime.deepsort_tracker import DeepSort
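The new tracker comes from the deep-sort-realtime package rather than a vendored deep_sort checkout, so the Space no longer needs the ckpt.t7 weights file. A minimal environment sketch for this import (the requirements.txt contents below are an assumption, not part of the commit):

    # requirements.txt (assumed)
    #   ultralytics
    #   deep-sort-realtime
    #   opencv-python-headless
    #   gradio
    from deep_sort_realtime.deepsort_tracker import DeepSort
    tracker = DeepSort(max_age=30)  # constructs the default appearance embedder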
+ class ObjectTracker:
+     def __init__(self, person_model_path='yolov8n.pt'):
+         """
+         Initialize object tracker with YOLO and DeepSort
+         """
+         # Load YOLO model for person detection
+         self.model = YOLO(person_model_path)
+
+         # Initialize DeepSort tracker
+         self.tracker = DeepSort(
+             max_age=30,  # tracks can be lost for up to 30 frames
+             n_init=3,    # consecutive detections before a track is confirmed
+         )
+
+         # Tracking statistics
+         self.person_count = 0
+         self.tracking_data = {}
 
      def process_frame(self, frame):
          """
          Process a single frame for object detection and tracking
          """
+         # Detect persons using YOLO (COCO class 0 is "person")
+         results = self.model(frame, classes=[0], conf=0.5)
+
+         # Extract bounding boxes and confidences
+         detections = []
+         for r in results:
+             boxes = r.boxes
+             for box in boxes:
+                 # Convert to the ([left, top, width, height], confidence, class)
+                 # triple that DeepSort's update_tracks expects
+                 x1, y1, x2, y2 = box.xyxy[0]
+                 bbox = [x1.item(), y1.item(), (x2 - x1).item(), (y2 - y1).item()]
+                 conf = box.conf.item()
+                 detections.append((bbox, conf, 'person'))
+
+         # Update tracks; default to an empty list so the annotation loop
+         # below is safe on frames with no detections
+         tracks = []
+         if detections:
+             tracks = self.tracker.update_tracks(
+                 detections,
+                 frame=frame
+             )
 
+         # Annotate frame with tracking information
+         for track in tracks:
+             if not track.is_confirmed():
+                 continue
+
+             track_id = track.track_id
+             ltrb = track.to_ltrb()
+
+             # Draw bounding box
+             cv2.rectangle(
+                 frame,
+                 (int(ltrb[0]), int(ltrb[1])),
+                 (int(ltrb[2]), int(ltrb[3])),
+                 (0, 255, 0),
+                 2
+             )
+
+             # Add track ID
+             cv2.putText(
+                 frame,
+                 f'ID: {track_id}',
+                 (int(ltrb[0]), int(ltrb[1] - 10)),
+                 cv2.FONT_HERSHEY_SIMPLEX,
+                 0.9,
+                 (0, 255, 0),
+                 2
+             )
+
+         return frame
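For quick testing of ObjectTracker outside Gradio, a minimal sketch (the image path is an assumption):

    import cv2
    tracker = ObjectTracker()             # default yolov8n.pt weights
    frame = cv2.imread('test_frame.jpg')  # assumed local test image
    annotated = tracker.process_frame(frame)
    cv2.imwrite('test_frame_tracked.jpg', annotated)

Note that with n_init=3 a track is only confirmed, and therefore drawn, after three consecutive detections, so a single still frame produces no boxes; feed several consecutive frames to see IDs appear.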
 
+ def process_video(input_video):
      """
+     Main video processing function for Gradio
      """
+     # Initialize tracker
+     tracker = ObjectTracker()
+
+     # Open input video
+     cap = cv2.VideoCapture(input_video)
+
+     # Prepare output video writer
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 when FPS is unreadable
+
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter('output_tracked.mp4', fourcc, fps, (width, height))
+
+     # Process video frames
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Process and annotate frame
+         processed_frame = tracker.process_frame(frame)
+
+         # Write processed frame
+         out.write(processed_frame)
+
+     # Release resources
+     cap.release()
+     out.release()
+
+     return 'output_tracked.mp4'
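The same pipeline can be exercised from a plain script, bypassing the UI (the clip path is an assumption):

    output_path = process_video('sample.mp4')  # assumed local clip
    print(f'Tracked video written to {output_path}')

One caveat: the mp4v FourCC used above is not playable by every browser's HTML5 video element; if the result shows a blank player in Gradio, re-encoding the output to H.264 is a common fix.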
 
+ # Create Gradio interface
  iface = gr.Interface(
+     fn=process_video,
+     inputs=gr.Video(label="Upload Video for Tracking"),
+     outputs=gr.Video(label="Tracked Video"),
+     title="Person Tracking with YOLO and DeepSort",
+     description="Upload a video to track and annotate person movements"
  )

  # Launch the interface
+ if __name__ == "__main__":
+     iface.launch()
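Long clips can exceed default request timeouts on Spaces; enabling Gradio's request queue is a common mitigation. A sketch (the max_size value is an assumption):

    if __name__ == "__main__":
        iface.queue(max_size=10).launch()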