Abs6187 committed on
Commit b488bef · verified · 1 Parent(s): cba3b2a

Update app.py

Files changed (1):
  1. app.py +177 -142
app.py CHANGED
@@ -1,164 +1,199 @@
- import gradio as gr
  import cv2
  import numpy as np
- from tensorflow.keras.models import load_model
- from sklearn.preprocessing import StandardScaler
  from ultralytics import YOLO

- # Load models
- lstm_model = load_model('suspicious_activity_model.h5')
- yolo_model = YOLO('yolov8n-pose.pt') # Ensure this model supports keypoint detection
- scaler = StandardScaler()

- def extract_keypoints(frame):
-     """
-     Extracts normalized keypoints from a frame using YOLO pose model.
-     """
-     results = yolo_model(frame, verbose=False)
-     for r in results:
-         if r.keypoints is not None and len(r.keypoints) > 0:
-             # Extract the first detected person's keypoints
-             keypoints = r.keypoints.xyn.tolist()[0] # Use the first person's keypoints
-             flattened_keypoints = [kp for keypoint in keypoints for kp in keypoint[:2]] # Flatten x, y values
-             return flattened_keypoints
-     return None # Return None if no keypoints are detected
-
- def process_input(input_media):
-     """
-     Process either a video or an image for suspicious activity detection
-     """
-     # Determine if input is a video or image path
-     is_video = input_media.lower().endswith(('.mp4', '.avi', '.mov'))
-
-     if is_video:
-         return process_video(input_media)
-     else:
-         return process_image(input_media)

- def process_image(image_path):
-     """
-     Process a single image for suspicious activity detection
-     """
-     # Read the image
-     frame = cv2.imread(image_path)
-
-     # Perform YOLO detection
-     results = yolo_model(frame, verbose=False)
-     for box in results[0].boxes:
-         cls = int(box.cls[0]) # Class ID
-         confidence = float(box.conf[0])

-         # Detect persons only (class_id 0 for 'person')
-         if cls == 0 and confidence > 0.5:
-             x1, y1, x2, y2 = map(int, box.xyxy[0]) # Bounding box coordinates
-
-             # Extract ROI for classification
-             roi = frame[y1:y2, x1:x2]
-             if roi.size > 0:
-                 # Preprocess ROI to extract keypoints
-                 keypoints = extract_keypoints(roi)
-
-                 if keypoints is not None and len(keypoints) > 0:
-                     # Standardize and reshape keypoints for LSTM input
-                     keypoints_scaled = scaler.fit_transform([keypoints]) # Standardize features
-                     keypoints_reshaped = keypoints_scaled.reshape((1, 1, len(keypoints))) # Reshape for LSTM
-
-                     # Predict with LSTM model
-                     prediction = (lstm_model.predict(keypoints_reshaped) > 0.5).astype(int)[0][0]
-
-                     # Draw bounding box and label
-                     color = (0, 0, 255) if prediction == 1 else (0, 255, 0)
-                     label = 'Suspicious' if prediction == 1 else 'Normal'
-                     cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
-                     cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
-                 else:
-                     print("No valid keypoints detected for ROI. Skipping.")
              else:
-                 print("ROI size is zero. Skipping.")
-
-     # Save the processed image
-     output_path = 'output_image.jpg'
-     cv2.imwrite(output_path, frame)
-     return output_path

- def process_video(input_video):
      """
-     Process video for suspicious activity detection
      """
-     # Open video capture
-     cap = cv2.VideoCapture(input_video)
-
-     # Prepare to save output video
-     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     fps = cap.get(cv2.CAP_PROP_FPS)

-     # Create VideoWriter object
-     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-     out = cv2.VideoWriter('output_video.mp4', fourcc, fps, (width, height))

-     # Process each frame
-     while cap.isOpened():
-         ret, frame = cap.read()
-         if not ret:
-             break
-
-         # Perform YOLO detection
-         results = yolo_model(frame, verbose=False)
-         for box in results[0].boxes:
-             cls = int(box.cls[0]) # Class ID
-             confidence = float(box.conf[0])
-
-             # Detect persons only (class_id 0 for 'person')
-             if cls == 0 and confidence > 0.5:
-                 x1, y1, x2, y2 = map(int, box.xyxy[0]) # Bounding box coordinates
-
-                 # Extract ROI for classification
-                 roi = frame[y1:y2, x1:x2]
-                 if roi.size > 0:
-                     # Preprocess ROI to extract keypoints
-                     keypoints = extract_keypoints(roi)
-
-                     if keypoints is not None and len(keypoints) > 0:
-                         # Standardize and reshape keypoints for LSTM input
-                         keypoints_scaled = scaler.fit_transform([keypoints]) # Standardize features
-                         keypoints_reshaped = keypoints_scaled.reshape((1, 1, len(keypoints))) # Reshape for LSTM
-
-                         # Predict with LSTM model
-                         prediction = (lstm_model.predict(keypoints_reshaped) > 0.5).astype(int)[0][0]
-
-                         # Draw bounding box and label
-                         color = (0, 0, 255) if prediction == 1 else (0, 255, 0)
-                         label = 'Suspicious' if prediction == 1 else 'Normal'
-                         cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
-                         cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
-                     else:
-                         print("No valid keypoints detected for ROI. Skipping frame.")
-                 else:
-                     print("ROI size is zero. Skipping frame.")
-
-         # Write processed frame to output video
-         out.write(frame)
-
-     # Release resources
-     cap.release()
-     out.release()
-
-     return 'output_video.mp4'

- # Create Gradio interface
  iface = gr.Interface(
-     fn=process_input,
      inputs=[
-         gr.File(label="Upload Image or Video",
-                 file_types=['image', 'video'],
-                 type="filepath")
      ],
      outputs=[
-         gr.File(label="Processed Media")
      ],
-     title="Suspicious Activity Detection",
-     description="Upload an image or video to detect suspicious activities using YOLO and LSTM models. Suspicious activities will be marked with red bounding boxes, normal activities with green."
  )

  # Launch the interface
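One note on the pipeline being removed: `scaler` is a fresh `StandardScaler` that is refit on a single keypoint vector at every detection via `scaler.fit_transform([keypoints])`. With only one sample the mean equals the sample and the zero variance falls back to a unit scale, so the LSTM always receives an all-zero feature vector. A minimal sketch of that behaviour, and of the fit-once-on-training-data usage the model presumably expects (the training array here is a hypothetical stand-in, not from this repo):

import numpy as np
from sklearn.preprocessing import StandardScaler

keypoints = np.random.rand(1, 34)                 # one flattened 17-keypoint (x, y) vector
print(StandardScaler().fit_transform(keypoints))  # -> [[0. 0. ... 0.]] for any input

train_keypoints = np.random.rand(100, 34)         # hypothetical training set
scaler = StandardScaler().fit(train_keypoints)    # fit once, reuse at inference
print(scaler.transform(keypoints))                # non-degenerate, informative features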
 
+ import time
+ import torch
  import cv2
  import numpy as np
+ import gradio as gr
  from ultralytics import YOLO
+ from deep_sort.utils.parser import get_config
+ from deep_sort.deep_sort import DeepSort

+ # Initialize YOLO and DeepSort
+ deep_sort_weights = 'ckpt.t7'
+ tracker = DeepSort(model_path=deep_sort_weights, max_age=80)
+ model = YOLO("person_gun.pt")

+ class ObjectDetector:
+     def __init__(self):
+         # Tracking variables
+         self.unique_track_ids = set()
+         self.track_labels = {}
+         self.track_times = {}
+         self.track_positions = {}
+         self.running_counters = {}
+         self.alert_person_ids = []
+
+         # Detection parameters
+         self.running_threshold = 0.5
+         self.fps = 30 # Default FPS

+     def process_frame(self, frame):
+         """
+         Process a single frame for object detection and tracking
+         """
+         # Reset alert tracking for this frame
+         self.alert_person_ids.clear()
+         og_frame = frame.copy()
+
+         # Detect persons
+         results = model(frame, device=0, classes=0, conf=0.75)

+         for result in results:
+             boxes = result.boxes
+             cls = boxes.cls.tolist()
+             conf = boxes.conf
+             xywh = boxes.xywh
+
+             pred_cls = np.array(cls)
+             conf = conf.detach().cpu().numpy()
+             bboxes_xywh = xywh.cpu().numpy()
+
+             # Update tracking
+             tracks = tracker.update(bboxes_xywh, conf, og_frame)
+             active_track_ids = set()
+
+             # Reset running status
+             new_running_status = "No Running"
+
+             for track in tracker.tracker.tracks:
+                 track_id = track.track_id
+                 x1, y1, x2, y2 = track.to_tlbr()
+                 w = x2 - x1
+                 h = y2 - y1
+
+                 # Define color for bounding box
+                 red_color = (0, 0, 255)
+                 blue_color = (255, 0, 0)
+                 green_color = (0, 255, 0)
+                 color_id = track_id % 3
+                 color = red_color if color_id == 0 else blue_color if color_id == 1 else green_color
+                 cv2.rectangle(og_frame, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
+
+                 # Initialize tracking for new tracks
+                 if track_id not in self.track_labels:
+                     self.track_labels[track_id] = "Person"
+                     self.track_times[track_id] = 0
+                     self.track_positions[track_id] = (x1, y1)
+                     self.running_counters[track_id] = 0
+
+                 self.track_times[track_id] += 1
+                 prev_x1, prev_y1 = self.track_positions[track_id]
+                 displacement = np.sqrt((x1 - prev_x1) ** 2 + (y1 - prev_y1) ** 2)
+
+                 # Calculate speed
+                 speed = displacement / self.fps if self.fps > 0 else 0
+
+                 self.track_positions[track_id] = (x1, y1)
+
+                 # Detect running
+                 if speed > self.running_threshold and w * h > 5000:
+                     self.running_counters[track_id] += 1
+                     if self.running_counters[track_id] > self.fps/2:
+                         self.track_labels[track_id] = "Running"
+                         new_running_status = "Running Detected"
                  else:
+                     self.running_counters[track_id] = 0
+                     self.track_labels[track_id] = "Person"
+
+                 # Track time and potential alerts
+                 total_seconds = self.track_times[track_id] / self.fps if self.fps > 0 else 0
+                 minutes = int(total_seconds // 60)
+                 seconds = int(total_seconds % 60)
+
+                 # Trigger alert for prolonged stay
+                 if total_seconds > 60 and track_id not in self.alert_person_ids:
+                     self.alert_person_ids.append(track_id)
+
+                 # Add label to frame
+                 cv2.putText(og_frame, f"{self.track_labels[track_id]} {minutes:02}:{seconds:02}",
+                             (int(x1) + 10, int(y1) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
+
+                 active_track_ids.add(track_id)
+
+             # Update unique track IDs
+             self.unique_track_ids.intersection_update(active_track_ids)
+             self.unique_track_ids.update(active_track_ids)
+
+         # Prepare result dictionary
+         result_info = {
+             'person_count': len(self.unique_track_ids),
+             'running_status': new_running_status,
+             'prolonged_stay_ids': list(self.alert_person_ids)
+         }
+
+         return og_frame, result_info
+
+     def process_input(self, input_media):
+         """
+         Process either video or webcam input
+         """
+         # Determine input type
+         if isinstance(input_media, str): # Video file path
+             cap = cv2.VideoCapture(input_media)
+         else: # Webcam input
+             cap = cv2.VideoCapture(0)
+
+         # Prepare output video
+         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+         fps = cap.get(cv2.CAP_PROP_FPS) or 30
+         self.fps = fps
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+         out = cv2.VideoWriter('output_detection.mp4', fourcc, fps, (width, height))
+
+         # Processing loop
+         frame_info_list = []
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+
+             # Process frame
+             processed_frame, frame_info = self.process_frame(frame)
+             out.write(processed_frame)
+             frame_info_list.append(frame_info)

+         # Release resources
+         cap.release()
+         out.release()
+
+         return 'output_detection.mp4', frame_info_list
+
+ # Create Gradio interface
+ detector = ObjectDetector()
+
+ def detect_interface(input_media):
      """
+     Gradio interface function for detection
      """
+     output_video, frame_info_list = detector.process_input(input_media)

+     # Generate text summary
+     summary = "Detection Summary:\n"
+     if frame_info_list:
+         # Take the last frame's information
+         last_frame_info = frame_info_list[-1]
+         summary += f"Total Persons Detected: {last_frame_info['person_count']}\n"
+         summary += f"Running Status: {last_frame_info['running_status']}\n"
+         if last_frame_info['prolonged_stay_ids']:
+             summary += f"Prolonged Stay Detected - Person IDs: {last_frame_info['prolonged_stay_ids']}"
+         else:
+             summary += "No Prolonged Stay Detected"

+     return output_video, summary

+ # Gradio Interface
  iface = gr.Interface(
+     fn=detect_interface,
      inputs=[
+         gr.File(label="Upload Video", type="filepath"),
+         gr.Webcam(label="Or Use Webcam")
      ],
      outputs=[
+         gr.Video(label="Processed Video"),
+         gr.Textbox(label="Detection Summary")
      ],
+     title="Object Detection with Tracking",
+     description="Upload a video or use webcam for real-time object detection and tracking"
  )

  # Launch the interface
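Two wiring notes on the added interface, given as a hedged sketch rather than a change to the commit itself: `gr.Interface` passes one value per input component, so the two declared inputs (`gr.File` and `gr.Webcam`) would call the one-argument `detect_interface` with two values, and `gr.Webcam` is not a top-level component in recent Gradio releases (3.x/4.x take webcam input through an image or video component's source setting). The hunk also stops at the `# Launch the interface` comment, so the launch call is not visible here. A minimal sketch under those assumptions, using Gradio 4.x component names and a single video input that accepts either an upload or the webcam:

# Hypothetical rewiring; component and parameter names follow Gradio 4.x
# and are not taken from the commit itself.
iface = gr.Interface(
    fn=detect_interface,  # one input component -> one argument
    inputs=gr.Video(sources=["upload", "webcam"], label="Upload Video or Use Webcam"),
    outputs=[
        gr.Video(label="Processed Video"),
        gr.Textbox(label="Detection Summary"),
    ],
    title="Object Detection with Tracking",
    description="Upload a video or use webcam for real-time object detection and tracking",
)

if __name__ == "__main__":
    iface.launch()  # assumed entry point; the launch call is not shown in the hunk

The video component returns a filepath string, which keeps the `isinstance(input_media, str)` branch in `process_input` on the file path. Separately, `process_frame` calls the YOLO model with `device=0`, which assumes a CUDA GPU; on CPU-only hardware the call would need `device="cpu"` (or the argument dropped), and the `ckpt.t7` DeepSort checkpoint and `person_gun.pt` weights must be present for the app to start, since both are loaded at import time.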