TreezzZ commited on
Commit
24a683e
·
verified ·
1 Parent(s): 24e29f0

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +250 -0
utils.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py - Helper functions for insect detection demo
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from typing import List, Dict, Tuple, Optional
7
+ from ultralytics import YOLO
8
+
9
+ def perform_detection(
10
+ yolo_model: YOLO,
11
+ frame: np.ndarray,
12
+ conf_threshold: float=0.5
13
+ ) -> Optional[List[Dict]]:
14
+ """
15
+ Runs the YOLO model inference on a single frame.
16
+ """
17
+ if frame is None:
18
+ print("Error: Input frame is None in perform_detection.")
19
+ return None
20
+ try:
21
+ # Perform inference using the model
22
+ results = yolo_model.predict(source=frame, conf=conf_threshold, verbose=False)
23
+ return results
24
+ except Exception as e:
25
+ print(f"Error during model prediction: {e}")
26
+ return None
27
+
28
+ def create_motion_mask(frame, threshold=25):
29
+ """
30
+ Creates a simple motion mask from an image.
31
+ For the demo, we'll use a basic thresholding approach.
32
+ """
33
+ # Convert to grayscale
34
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
35
+
36
+ # Apply Gaussian blur to reduce noise
37
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
38
+
39
+ # Apply adaptive thresholding
40
+ thresh = cv2.adaptiveThreshold(
41
+ blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
42
+ cv2.THRESH_BINARY_INV, 11, threshold
43
+ )
44
+
45
+ # Apply morphological operations to clean up the mask
46
+ kernel = np.ones((3, 3), np.uint8)
47
+ mask = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=1)
48
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
49
+
50
+ return mask
51
+
52
+ def postprocess_results(
53
+ results: Optional[List[Dict]],
54
+ model_class_names: Dict[int, str],
55
+ mask: Optional[np.ndarray] = None
56
+ ) -> List[Dict]:
57
+ """
58
+ Extracts information from detection results.
59
+ If a mask is provided, only keeps detections that overlap with the mask.
60
+ """
61
+ detections_list = []
62
+ if results is None or not results:
63
+ return detections_list
64
+
65
+ try:
66
+ boxes = results[0].boxes
67
+ except (IndexError, AttributeError) as e:
68
+ print(f"Warning: Could not access boxes in results: {e}")
69
+ return detections_list
70
+
71
+ for box in boxes:
72
+ try:
73
+ # Extract bounding box coordinates (xyxy format)
74
+ xyxy = box.xyxy[0].cpu().numpy().astype(int)
75
+ x1, y1, x2, y2 = xyxy
76
+
77
+ # If we have a mask, check if this detection overlaps with it
78
+ if mask is not None:
79
+ # Check if the center of the bounding box is within the mask
80
+ center_x = (x1 + x2) // 2
81
+ center_y = (y1 + y2) // 2
82
+
83
+ # Also check if a significant portion of the box overlaps with the mask
84
+ # First make sure we stay within mask boundaries
85
+ y1_safe = max(0, min(y1, mask.shape[0]-1))
86
+ y2_safe = max(0, min(y2, mask.shape[0]-1))
87
+ x1_safe = max(0, min(x1, mask.shape[1]-1))
88
+ x2_safe = max(0, min(x2, mask.shape[1]-1))
89
+
90
+ # Extract the region of the mask corresponding to the bounding box
91
+ box_region = mask[y1_safe:y2_safe, x1_safe:x2_safe]
92
+
93
+ # Calculate overlap
94
+ if box_region.size > 0:
95
+ mask_coverage = np.sum(box_region > 0) / box_region.size
96
+ else:
97
+ mask_coverage = 0
98
+
99
+ # Skip this detection if it doesn't overlap with the mask
100
+ if not (0 <= center_y < mask.shape[0] and 0 <= center_x < mask.shape[1] and
101
+ (mask[center_y, center_x] > 0 or mask_coverage > 0.5)):
102
+ continue
103
+
104
+ # Extract confidence score
105
+ conf = float(box.conf[0].cpu().numpy())
106
+
107
+ # Extract class ID and map to class name
108
+ cls_id = int(box.cls[0].cpu().numpy())
109
+ class_name = model_class_names.get(cls_id, f"Unknown Class {cls_id}")
110
+
111
+ # Store detection info
112
+ detections_list.append({
113
+ 'class_name': class_name,
114
+ 'confidence': conf,
115
+ 'bbox_xyxy': [x1, y1, x2, y2]
116
+ })
117
+ except Exception as e:
118
+ print(f"Error processing a detection box: {e}")
119
+ continue
120
+ return detections_list
121
+
122
+ def draw_detections(
123
+ frame: np.ndarray,
124
+ detections: List[Dict],
125
+ mask: Optional[np.ndarray] = None
126
+ ) -> np.ndarray:
127
+ """
128
+ Draws bounding boxes and labels on the frame.
129
+ If mask is provided, overlays it on the frame.
130
+ """
131
+ output_frame = frame.copy()
132
+
133
+ # If we have a mask, overlay it with transparency
134
+ if mask is not None and mask.shape[0] > 0 and mask.shape[1] > 0:
135
+ # Create a colored mask for overlay
136
+ mask_overlay = np.zeros_like(output_frame)
137
+ mask_overlay[mask > 0] = [0, 100, 0] # Green tint for mask regions
138
+
139
+ # Blend mask with the frame
140
+ output_frame = cv2.addWeighted(output_frame, 0.7, mask_overlay, 0.3, 0)
141
+
142
+ # Draw bounding boxes
143
+ color = (0, 255, 0) # Green color for bounding box
144
+ font_scale = 1.2
145
+ font = cv2.FONT_HERSHEY_SIMPLEX
146
+ for detection in detections:
147
+ try:
148
+ x1, y1, x2, y2 = detection['bbox_xyxy']
149
+ class_name = detection['class_name']
150
+ conf = detection['confidence']
151
+
152
+ # Draw Bounding Box
153
+ cv2.rectangle(output_frame, (x1, y1), (x2, y2), color, 5)
154
+
155
+ # Prepare and Draw Label
156
+ label = f"{class_name}: {conf:.2f}"
157
+
158
+ # Calculate text size for background
159
+ (label_width, label_height), baseline = cv2.getTextSize(label, font, font_scale, 3)
160
+ label_ymin = max(y1, label_height + 10)
161
+
162
+ # Draw background for text
163
+ cv2.rectangle(output_frame,
164
+ (x1, label_ymin - label_height - 10),
165
+ (x1 + label_width, label_ymin - baseline),
166
+ color,
167
+ cv2.FILLED)
168
+
169
+ # Add text
170
+ cv2.putText(output_frame,
171
+ label,
172
+ (x1, label_ymin - 5),
173
+ font,
174
+ font_scale,
175
+ (255, 255, 255), # White color
176
+ 3)
177
+ except Exception as e:
178
+ continue
179
+ return output_frame
180
+
181
+ def load_yolo_model(model_path):
182
+ """
183
+ Loads the YOLO model from the specified path.
184
+ """
185
+ print("Loading the YOLO model...")
186
+ try:
187
+ model = YOLO(model_path)
188
+ class_names = model.names
189
+ print(f"Model loaded with {len(class_names)} classes!")
190
+ return model, class_names
191
+ except Exception as e:
192
+ print(f"Error loading model: {e}")
193
+ return None, None
194
+
195
+ def load_image(image_path):
196
+ """
197
+ Loads an image from the specified path.
198
+ """
199
+ print(f"Opening image: {image_path}")
200
+ image = cv2.imread(image_path)
201
+ if image is None:
202
+ print(f"Error: Could not read image file '{image_path}'.")
203
+ return image
204
+
205
+ def load_or_create_mask(image, mask_path=None):
206
+ """
207
+ Either loads a mask from disk or creates a new one from the image.
208
+ """
209
+ if mask_path and os.path.exists(mask_path):
210
+ print(f"Loading mask: {mask_path}")
211
+ mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
212
+ else:
213
+ print("Creating mask from image...")
214
+ mask = create_motion_mask(image)
215
+
216
+ return mask
217
+
218
+ def display_results(output_frame, detections, mask=None):
219
+ """
220
+ Displays detection results and saves the output image.
221
+ """
222
+ # Display results in console
223
+ print("\n--- Insects Detected ---")
224
+ if detections:
225
+ for i, obj in enumerate(detections, 1):
226
+ print(f"{i}. {obj['class_name']} (confidence: {obj['confidence']:.2f})")
227
+ else:
228
+ print("No insects detected.")
229
+
230
+ # Save mask if it exists
231
+ if mask is not None:
232
+ cv2.imwrite("motion_mask.png", mask)
233
+ print("Motion mask saved to: motion_mask.png")
234
+
235
+ # Convert from BGR to RGB for display
236
+ output_rgb = cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB)
237
+
238
+ # Display the image
239
+ plt.figure(figsize=(10, 8))
240
+ plt.imshow(output_rgb)
241
+ plt.title("Insect Detection Results")
242
+ plt.axis('off')
243
+ plt.show()
244
+
245
+ # Save result
246
+ result_path = "detection_result.jpg"
247
+ cv2.imwrite(result_path, output_frame)
248
+ print(f"Result saved to: {result_path}")
249
+
250
+ import os # Added for file path operations