Abs6187 committed on
Commit 47ee765
1 Parent(s): f4defdf

Update app.py

Files changed (1)
  app.py +97 -38
app.py CHANGED
@@ -1,51 +1,110 @@
- # app.py
  import gradio as gr
  import cv2
  import numpy as np
- from model import SuspiciousActivityModel  # Import the model

- # Initialize the model paths
- lstm_model_path = 'suspicious_activity_model.h5'  # Path to your LSTM model
- yolo_model_path = 'yolov8n-pose.pt'  # Path to your YOLO model

- # Initialize the suspicious activity model
- model = SuspiciousActivityModel(lstm_model_path, yolo_model_path)

- # Function to process video frame
- def process_video(video_frame):
-     # Check if the input frame is a valid NumPy array
-     if isinstance(video_frame, np.ndarray):
-         print(f"Frame shape: {video_frame.shape}")  # Print the shape of the frame for debugging

-         # Convert frame from BGR to RGB (OpenCV uses BGR by default)
-         try:
-             frame_rgb = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB)
-         except cv2.error as e:
-             print(f"Error in cvtColor: {e}")
-             return video_frame  # Return the original frame if error occurs
-
-         # Call model to detect activity in the frame
-         label = model.detect_activity(frame_rgb)

-         # Add label to the frame (Optional: you can also draw bounding boxes)
-         cv2.putText(frame_rgb, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

-         # Convert back to BGR for Gradio (since it expects BGR format)
-         frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
-         return frame_bgr
-     else:
-         print("Received invalid frame format")
-         return video_frame  # Return the original frame if it's not valid

- # Gradio interface
  iface = gr.Interface(
-     fn=process_video,  # Function that processes each frame
-     inputs=gr.Video(type="webcam", streaming=True),  # Use webcam as input
-     outputs=gr.Video(),  # Output is also a video
-     live=True,  # Stream the video in real time
-     title="Suspicious Activity Detection"  # Interface title
  )

- # Launch the app with public link
- if __name__ == "__main__":
-     iface.launch(share=True)  # Set share=True to create a public link

  import gradio as gr
  import cv2
  import numpy as np
+ from tensorflow.keras.models import load_model
+ from sklearn.preprocessing import StandardScaler
+ from ultralytics import YOLO

+ # Load models
+ lstm_model = load_model('suspicious_activity_model.h5')
+ yolo_model = YOLO('yolov8n-pose.pt')  # Ensure this model supports keypoint detection
+ scaler = StandardScaler()

+ def extract_keypoints(frame):
+     """
+     Extracts normalized keypoints from a frame using YOLO pose model.
+     """
+     results = yolo_model(frame, verbose=False)
+     for r in results:
+         if r.keypoints is not None and len(r.keypoints) > 0:
+             # Extract the first detected person's keypoints
+             keypoints = r.keypoints.xyn.tolist()[0]  # Use the first person's keypoints
+             flattened_keypoints = [kp for keypoint in keypoints for kp in keypoint[:2]]  # Flatten x, y values
+             return flattened_keypoints
+     return None  # Return None if no keypoints are detected

+ def process_frame(frame):
+     """
+     Process each frame for suspicious activity detection
+     """
+     # Perform YOLO detection
+     results = yolo_model(frame, verbose=False)
+     for box in results[0].boxes:
+         cls = int(box.cls[0])  # Class ID
+         confidence = float(box.conf[0])

+         # Detect persons only (class_id 0 for 'person')
+         if cls == 0 and confidence > 0.5:
+             x1, y1, x2, y2 = map(int, box.xyxy[0])  # Bounding box coordinates
+
+             # Extract ROI for classification
+             roi = frame[y1:y2, x1:x2]
+             if roi.size > 0:
+                 # Preprocess ROI to extract keypoints
+                 keypoints = extract_keypoints(roi)
+
+                 if keypoints is not None and len(keypoints) > 0:
+                     # Standardize and reshape keypoints for LSTM input
+                     keypoints_scaled = scaler.fit_transform([keypoints])  # Standardize features
+                     keypoints_reshaped = keypoints_scaled.reshape((1, 1, len(keypoints)))  # Reshape for LSTM
+
+                     # Predict with LSTM model
+                     prediction = (lstm_model.predict(keypoints_reshaped) > 0.5).astype(int)[0][0]
+
+                     # Draw bounding box and label
+                     color = (0, 0, 255) if prediction == 1 else (0, 255, 0)
+                     label = 'Suspicious' if prediction == 1 else 'Normal'
+                     cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
+                     cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+                 else:
+                     print("No valid keypoints detected for ROI. Skipping frame.")
+             else:
+                 print("ROI size is zero. Skipping frame.")
+
+     return frame
+
+ def detect_suspicious_activity(input_video):
+     """
+     Main function to process video for suspicious activity detection
+     """
+     # Open video capture
+     cap = cv2.VideoCapture(input_video)
+
+     # Prepare to save output video
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fps = cap.get(cv2.CAP_PROP_FPS)
+
+     # Create VideoWriter object
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter('output_video.mp4', fourcc, fps, (width, height))
+
+     # Process each frame
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break

+         # Process and annotate frame
+         processed_frame = process_frame(frame)

+         # Write processed frame to output video
+         out.write(processed_frame)
+
+     # Release resources
+     cap.release()
+     out.release()
+
+     return 'output_video.mp4'

+ # Create Gradio interface
  iface = gr.Interface(
+     fn=detect_suspicious_activity,
+     inputs=gr.Video(label="Upload Video"),
+     outputs=gr.Video(label="Processed Video"),
+     title="Suspicious Activity Detection",
+     description="Upload a video to detect suspicious activities using YOLO and LSTM models"
  )

+ # Launch the interface
+ iface.launch()
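
A note on the standardization step in the new version: scaler = StandardScaler() is created unfitted, and scaler.fit_transform([keypoints]) re-fits it on a single sample at every detection. With only one sample the per-feature variance is zero, so the transform returns an all-zero vector and the LSTM sees the same input regardless of the pose. The usual remedy is to fit the scaler once on the training data and call only transform() at inference. A minimal sketch, assuming the training-time scaler was saved with joblib; the filename 'scaler.pkl' and the helper name are hypothetical:

import joblib

# Assumed artifact: a StandardScaler fitted on the training keypoints
scaler = joblib.load('scaler.pkl')

def scale_keypoints(keypoints):
    # transform() applies the stored training mean/std without re-fitting,
    # so a single keypoint vector is not collapsed to zeros
    return scaler.transform([keypoints])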
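
Relatedly, the reshape to (1, 1, len(keypoints)) means the LSTM's input width silently follows whatever length extract_keypoints returns, so a detection with an unexpected keypoint count would change the input shape between frames and fail at predict time. A fixed-width guard avoids that; the sketch assumes the 17 COCO keypoints that yolov8n-pose emits (34 x/y values), and EXPECTED_FEATURES / prepare_for_lstm are illustrative names:

import numpy as np

EXPECTED_FEATURES = 34  # assumption: 17 COCO keypoints x 2 normalized coordinates

def prepare_for_lstm(keypoints):
    # Pad with zeros or truncate so every frame presents the same feature width
    padded = (list(keypoints) + [0.0] * EXPECTED_FEATURES)[:EXPECTED_FEATURES]
    return np.asarray(padded, dtype=np.float32).reshape(1, 1, EXPECTED_FEATURES)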
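
Finally, cv2.CAP_PROP_FPS can report 0 for some containers, which would make the VideoWriter produce an unplayable output_video.mp4. A small fallback before constructing the writer (the 25.0 default is an arbitrary assumption):

fps = cap.get(cv2.CAP_PROP_FPS)
if not fps or fps <= 0:
    fps = 25.0  # assumed default when the container does not report a frame rate

Note also that the new version drops share=True from iface.launch(): run locally, Gradio prints a local URL, and share=True can be re-added for a temporary public link as in the previous version.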