mlbench123 committed on
Commit
39b7b21
·
verified ·
1 Parent(s): 7b89288

Upload 4 files

Browse files
Files changed (4)
  1. GMM.py +949 -0
  2. app_s_a_LiveCam.py +1157 -0
  3. requirements.txt +29 -0
  4. send_discord.py +172 -0
GMM.py ADDED
@@ -0,0 +1,949 @@
import os
import numpy as np
import cv2 as cv
from numpy.linalg import norm
import joblib  # or import pickle
import torch
from torch.distributions import MultivariateNormal
from typing import Tuple, Optional

init_weight = [0.7, 0.11, 0.1, 0.09]
init_u = np.zeros(3)
# initial covariance matrix
init_sigma = 225 * np.eye(3)
init_alpha = 0.05

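# The matching test used throughout this file compares a pixel against a
# Gaussian component via the Mahalanobis distance
#   d = sqrt((x - mu)^T Sigma^-1 (x - mu))
# and declares a match when d < 2.5, i.e. within 2.5 standard deviations.
# The function below is a small illustrative sketch of that test on a single
# pixel; it is not called anywhere in this module, and the sample values are
# made up for demonstration.
def _demo_mahalanobis_match():
    pixel = torch.tensor([120.0, 118.0, 119.0])  # hypothetical BGR pixel
    mu = torch.tensor([115.0, 115.0, 115.0])     # component mean
    sigma = torch.eye(3) * 225.0                 # init_sigma: std dev 15 per channel
    delta = pixel - mu
    d = torch.sqrt(delta @ torch.linalg.inv(sigma) @ delta)
    # Here d = sqrt((25 + 9 + 16) / 225) = sqrt(50) / 15 ~= 0.47 < 2.5 -> match
    return d.item() < 2.5
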
class GMM():
    def __init__(self, data_dir, train_num, alpha=init_alpha):
        self.data_dir = data_dir
        self.train_num = train_num
        self.alpha = alpha
        self.img_shape = None

        self.weight = None
        self.mu = None
        self.sigma = None
        self.K = None
        self.B = None

    def check(self, pixel, mu, sigma):
        '''
        Check whether a pixel matches a Gaussian distribution.
        Matching means the Mahalanobis distance is less than 2.5.
        '''
        # Convert to torch tensors
        if isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).float()
        if isinstance(sigma, np.ndarray):
            sigma = torch.from_numpy(sigma).float()
        if isinstance(pixel, np.ndarray):
            pixel = torch.from_numpy(pixel).float()

        # Ensure all are on the same device
        device = mu.device
        pixel = pixel.to(device)
        sigma = sigma.to(device)

        # Compute Mahalanobis distance
        delta = pixel - mu
        sigma_inv = torch.linalg.inv(sigma)
        d_squared = delta @ sigma_inv @ delta
        d = torch.sqrt(d_squared + 1e-5)

        # 2.5 standard deviations, consistent with the docstring and train()
        return d.item() < 2.5

    @staticmethod
    def rgba_to_rgb_for_processing(image_path):
        img = cv.imread(image_path, cv.IMREAD_UNCHANGED)

        if img.shape[2] == 4:  # RGBA
            # Create white background
            rgb_img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255

            # Alpha blending: composite the image over the white background
            alpha = img[:, :, 3:4] / 255.0
            rgb_img = rgb_img * (1 - alpha) + img[:, :, :3] * alpha

            return rgb_img.astype(np.uint8)
        else:
            return img

    def train(self, K=4):
        '''
        Train the model, with GPU acceleration when available.
        '''
        self.K = K
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")

        file_list = []
        for i in range(self.train_num):
            file_name = os.path.join(self.data_dir, 'b%05d' % i + '.png')
            file_list.append(file_name)

        # Initialize with first image
        img_init = cv.imread(file_list[0])
        img_shape = img_init.shape
        self.img_shape = img_shape
        height, width, channels = img_shape

        # Initialize model parameters on the device
        self.weight = torch.full((height, width, K), 1.0/K,
                                 dtype=torch.float32, device=device)
        self.mu = torch.zeros(height, width, K, 3,
                              dtype=torch.float32, device=device)
        self.sigma = torch.zeros(height, width, K, 3, 3,
                                 dtype=torch.float32, device=device)
        self.B = torch.ones((height, width),
                            dtype=torch.int32, device=device)

        # Initialize mu with first image values
        img_tensor = torch.from_numpy(img_init).float().to(device)
        for k in range(K):
            self.mu[:, :, k, :] = img_tensor

        # Initialize sigma with identity matrix * 225
        self.sigma[:] = torch.eye(3, device=device) * 225

        # Training loop
        for file in file_list:
            print('training:{}'.format(file))
            img = cv.imread(file)
            img_tensor = torch.from_numpy(img).float().to(device)  # (H,W,3)

            # Calculate the Mahalanobis distance to every distribution once per frame
            delta = img_tensor.unsqueeze(2) - self.mu   # (H,W,K,3)
            sigma_inv = torch.linalg.inv(self.sigma)    # (H,W,K,3,3)

            # Compute (x-mu)^T Sigma^-1 (x-mu)
            temp = torch.einsum('hwki,hwkij->hwkj', delta, sigma_inv)
            mahalanobis = torch.sqrt(torch.einsum('hwki,hwki->hwk', temp, delta) + 1e-5)

            # Assign each pixel to the first matching distribution
            # (distance < 2.5 and not already matched)
            matches = torch.full((height, width), -1, dtype=torch.long, device=device)
            for k in range(K):
                match_mask = (mahalanobis[:, :, k] < 2.5) & (matches == -1)
                matches[match_mask] = k

            # Process matched pixels
            for k in range(K):
                # Get mask for current distribution matches
                mask = matches == k
                if mask.any():
                    # Get matched pixels
                    matched_pixels = img_tensor[mask]                # (N,3)
                    matched_mu = self.mu[:, :, k, :][mask]           # (N,3)
                    matched_sigma = self.sigma[:, :, k, :, :][mask]  # (N,3,3)

                    try:
                        # Create multivariate normal distribution
                        mvn = MultivariateNormal(matched_mu,
                                                 covariance_matrix=matched_sigma)

                        # Calculate rho
                        rho = self.alpha * torch.exp(mvn.log_prob(matched_pixels))

                        # Update weights
                        self.weight[:, :, k][mask] = (1 - self.alpha) * self.weight[:, :, k][mask] + self.alpha

                        # Update mu
                        delta = matched_pixels - matched_mu
                        self.mu[:, :, k, :][mask] += rho.unsqueeze(1) * delta

                        # Update sigma
                        delta_outer = torch.einsum('bi,bj->bij', delta, delta)
                        sigma_update = rho.unsqueeze(1).unsqueeze(2) * (delta_outer - matched_sigma)
                        self.sigma[:, :, k, :, :][mask] += sigma_update

                    except RuntimeError as e:
                        print(f"Error updating distribution {k}: {e}")
                        continue

            # Process non-matched pixels
            non_matched = matches == -1
            if non_matched.any():
                # Find the least probable distribution for each non-matched pixel
                weight_non_matched = self.weight[non_matched]             # (N, K)
                min_weight_idx = torch.argmin(weight_non_matched, dim=1)  # (N,)

                # Coordinates of the non-matched pixels
                non_matched_indices = non_matched.nonzero(as_tuple=False)  # (N, 2)

                for k in range(K):
                    # Find positions where min_weight_idx == k
                    k_mask = (min_weight_idx == k)
                    if k_mask.any():
                        selected_indices = non_matched_indices[k_mask]  # (M, 2)
                        y_idx = selected_indices[:, 0]
                        x_idx = selected_indices[:, 1]

                        # Replace the least probable distribution with the new pixel
                        self.mu[y_idx, x_idx, k, :] = img_tensor[y_idx, x_idx]
                        self.sigma[y_idx, x_idx, k, :, :] = torch.eye(3, device=device) * 225

            # Convert to numpy for reordering and debug prints
            weight_np = self.weight.cpu().numpy()
            mu_np = self.mu.cpu().numpy()
            sigma_np = self.sigma.cpu().numpy()
            B_np = self.B.cpu().numpy()

            print('img:{}'.format(img[100][100]))
            print('weight:{}'.format(weight_np[100][100]))

            # reorder() operates on numpy arrays
            self.weight = weight_np
            self.mu = mu_np
            self.sigma = sigma_np
            self.B = B_np

            self.reorder()
            for i in range(self.K):
                print('u:{}'.format(self.mu[100][100][i]))

            # Move back to the device for the next iteration
            self.weight = torch.from_numpy(self.weight).to(device)
            self.mu = torch.from_numpy(self.mu).to(device)
            self.sigma = torch.from_numpy(self.sigma).to(device)
            self.B = torch.from_numpy(self.B).to(device)

    def save_model(self, file_path):
        """
        Save the trained model to a file.
        """
        # Only make directories if there is a directory in the path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        joblib.dump({
            'weight': self.weight,
            'mu': self.mu,
            'sigma': self.sigma,
            'K': self.K,
            'B': self.B,
            'img_shape': self.img_shape,
            'alpha': self.alpha,
            'data_dir': self.data_dir,
            'train_num': self.train_num
        }, file_path)

        print(f"Model saved to {file_path}")

    @classmethod
    def load_model(cls, file_path):
        """
        Load a trained model from file.
        """
        data = joblib.load(file_path)

        # Create new instance
        gmm = cls(data['data_dir'], data['train_num'], data['alpha'])

        # Restore all attributes
        gmm.weight = data['weight']
        gmm.mu = data['mu']
        gmm.sigma = data['sigma']
        gmm.K = data['K']
        gmm.B = data['B']
        gmm.img_shape = data['img_shape']
        gmm.image_shape = data['img_shape']

        print(f"Model loaded from {file_path}")
        return gmm

    # @classmethod
    # def load_model(cls, file_path):
    #     """
    #     Load a trained model safely onto CPU, even if saved from GPU.
    #     """
    #     import pickle
    #
    #     def cpu_load(path):
    #         with open(path, "rb") as f:
    #             unpickler = pickle._Unpickler(f)
    #             unpickler.persistent_load = lambda saved_id: torch.load(saved_id, map_location="cpu")
    #             return unpickler.load()
    #
    #     # Force joblib to use pickle with CPU-mapped tensors
    #     data = cpu_load(file_path)
    #
    #     # Create instance
    #     gmm = cls(data['data_dir'], data['train_num'], data['alpha'])
    #
    #     # Assign all attributes (already CPU tensors now)
    #     gmm.weight = data['weight']
    #     gmm.mu = data['mu']
    #     gmm.sigma = data['sigma']
    #     gmm.K = data['K']
    #     gmm.B = data['B']
    #     gmm.img_shape = data['img_shape']
    #     gmm.image_shape = data['img_shape']
    #
    #     print(f"✅ GMM model loaded on CPU from {file_path}")
    #     return gmm

    def reorder(self, T=0.90):
        '''
        Reorder the estimated components by the ratio of weight to the norm
        of the standard deviation. The first B components are chosen as
        background components. The default threshold is 0.90.
        '''
        epsilon = 1e-6  # to prevent divide-by-zero

        for i in range(self.img_shape[0]):
            for j in range(self.img_shape[1]):
                k_weight = self.weight[i][j]
                k_norm = []

                for k in range(self.K):
                    cov = self.sigma[i][j][k]
                    try:
                        if np.all(np.linalg.eigvals(cov) >= 0):
                            stddev = np.sqrt(np.maximum(cov, epsilon))
                            k_norm.append(norm(stddev))
                        else:
                            k_norm.append(epsilon)
                    except Exception:
                        k_norm.append(epsilon)

                k_norm = np.array(k_norm)
                ratio = k_weight / (k_norm + epsilon)
                descending_order = np.argsort(-ratio)

                self.weight[i][j] = self.weight[i][j][descending_order]
                self.mu[i][j] = self.mu[i][j][descending_order]
                self.sigma[i][j] = self.sigma[i][j][descending_order]

                cum_weight = 0
                for index, order in enumerate(descending_order):
                    cum_weight += self.weight[i][j][index]
                    if cum_weight > T:
                        self.B[i][j] = index + 1
                        break

    def region_propfill_enhancement(self, binary_mask: np.ndarray,
                                    table_mask: Optional[np.ndarray] = None,  # ADDED parameter
                                    dilation_kernel_size: int = 5,
                                    dilation_iterations: int = 2,
                                    erosion_iterations: int = 1,
                                    fill_threshold: int = 200,
                                    min_contour_area: int = 50) -> Tuple[np.ndarray, dict]:
        """
        Enhance GMM binary prediction mask using dilation and region filling.

        Args:
            binary_mask: Binary mask from GMM detection (True for detected foreground)
            table_mask: Optional binary mask defining table area (restricts processing)
            dilation_kernel_size: Size of dilation kernel (odd number)
            dilation_iterations: Number of dilation iterations to connect fragments
            erosion_iterations: Number of erosion iterations to restore original size
            fill_threshold: Threshold for flood fill operation (currently unused)
            min_contour_area: Minimum contour area to consider for processing

        Returns:
            enhanced_mask: Improved binary mask with filled regions
            debug_info: Dictionary containing intermediate results for debugging
        """

        # Convert boolean mask to uint8 if needed
        if binary_mask.dtype == bool:
            mask_uint8 = (binary_mask * 255).astype(np.uint8)
        else:
            mask_uint8 = binary_mask.astype(np.uint8)

        # Apply table mask if provided - CRITICAL FIX
        if table_mask is not None:
            # Ensure table_mask matches dimensions
            if table_mask.shape != mask_uint8.shape:
                table_mask = cv.resize(table_mask.astype(np.uint8),
                                       (mask_uint8.shape[1], mask_uint8.shape[0]),
                                       interpolation=cv.INTER_NEAREST) > 0
            # Zero out everything outside table area
            mask_uint8[~table_mask] = 0

        # Store original for comparison
        original_mask = mask_uint8.copy()

        # Step 1: Apply dilation to connect fragmented detections
        kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE,
                                          (dilation_kernel_size, dilation_kernel_size))

        # Dilate to connect nearby fragments
        dilated_mask = cv.dilate(mask_uint8, kernel, iterations=dilation_iterations)

        # Step 2: Apply flood fill to fill internal holes
        filled_mask = dilated_mask.copy()

        # Find contours to identify individual objects
        contours, _ = cv.findContours(dilated_mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)

        # Process each contour separately
        enhanced_mask = np.zeros_like(filled_mask)

        for contour in contours:
            # Filter out small contours
            if cv.contourArea(contour) < min_contour_area:
                continue

            # Create mask for this contour
            contour_mask = np.zeros_like(filled_mask)
            cv.drawContours(contour_mask, [contour], -1, 255, -1)

            # Get bounding rectangle
            x, y, w_rect, h_rect = cv.boundingRect(contour)

            # Create region of interest
            roi = contour_mask[y:y+h_rect, x:x+w_rect].copy()

            if roi.size == 0:
                continue

            # Apply flood fill from borders to mark external areas
            roi_filled = roi.copy()
            roi_h, roi_w = roi_filled.shape

            # Create flood mask for ROI (needs to be 2 pixels larger)
            roi_flood_mask = np.zeros((roi_h + 2, roi_w + 2), np.uint8)

            # Collect background seed points along the ROI border
            border_points = []
            # Top and bottom borders
            for i in range(roi_w):
                if roi_filled[0, i] == 0:
                    border_points.append((i, 0))
                if roi_filled[roi_h-1, i] == 0:
                    border_points.append((i, roi_h-1))

            # Left and right borders
            for i in range(roi_h):
                if roi_filled[i, 0] == 0:
                    border_points.append((0, i))
                if roi_filled[i, roi_w-1] == 0:
                    border_points.append((roi_w-1, i))

            # Flood fill the ROI copy itself from the border points, so that
            # everything reachable from the border becomes foreground-valued;
            # pixels still at 0 afterwards are internal holes
            for point in border_points:
                if roi_filled[point[1], point[0]] == 0:
                    cv.floodFill(roi_filled, roi_flood_mask, point, 255)

            # Invert to get the internal holes
            internal_mask = cv.bitwise_not(roi_filled)

            # Combine the holes with the original contour
            filled_contour = cv.bitwise_or(roi, internal_mask)

            # Place back in full image
            enhanced_mask[y:y+h_rect, x:x+w_rect] = cv.bitwise_or(
                enhanced_mask[y:y+h_rect, x:x+w_rect], filled_contour)

        # Step 3: Optional erosion to restore approximate original size
        if erosion_iterations > 0:
            erosion_kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE,
                                                      (dilation_kernel_size, dilation_kernel_size))
            enhanced_mask = cv.erode(enhanced_mask, erosion_kernel, iterations=erosion_iterations)

        # Step 4: Ensure we don't lose original detections AND respect table boundary
        enhanced_mask = cv.bitwise_or(enhanced_mask, original_mask)

        # RE-APPLY TABLE MASK - ensure no processing outside table
        if table_mask is not None:
            enhanced_mask[~table_mask] = 0

        # Convert back to boolean if input was boolean
        if binary_mask.dtype == bool:
            enhanced_mask = enhanced_mask > 0

        # Create debug info
        debug_info = {
            'original_mask': original_mask,
            'dilated_mask': dilated_mask,
            'enhanced_mask': enhanced_mask,
            'num_contours_processed': len([c for c in contours if cv.contourArea(c) >= min_contour_area])
        }

        return enhanced_mask, debug_info

    def draw_heatmap_colorbar(self, frame: np.ndarray, heatmap: np.ndarray) -> np.ndarray:
        """
        Draw a vertical heatmap color bar on the right side of the frame.

        Args:
            frame: Original frame
            heatmap: Heatmap array with values 0-1

        Returns:
            Frame with color bar overlay
        """
        height, width = frame.shape[:2]

        # Color bar dimensions
        bar_width = 30
        bar_height = int(height * 0.6)
        bar_x = width - bar_width - 20
        bar_y = int(height * 0.2)

        # Create gradient color bar (1.0 at the top, 0.0 at the bottom)
        gradient = np.linspace(1, 0, bar_height).reshape(-1, 1)
        gradient = np.tile(gradient, (1, bar_width))

        # Convert to color using JET colormap
        gradient_colored = cv.applyColorMap((gradient * 255).astype(np.uint8), cv.COLORMAP_JET)

        # Add border and background
        cv.rectangle(frame, (bar_x - 2, bar_y - 2),
                     (bar_x + bar_width + 2, bar_y + bar_height + 2), (255, 255, 255), 2)
        cv.rectangle(frame, (bar_x - 1, bar_y - 1),
                     (bar_x + bar_width + 1, bar_y + bar_height + 1), (0, 0, 0), 1)

        # Place color bar
        frame[bar_y:bar_y+bar_height, bar_x:bar_x+bar_width] = gradient_colored

        # Add labels
        labels = ["1.0", "0.75", "0.5", "0.25", "0.0"]
        label_positions = [0, 0.25, 0.5, 0.75, 1.0]

        for label, pos in zip(labels, label_positions):
            y_pos = bar_y + int(pos * bar_height)
            cv.putText(frame, label, (bar_x + bar_width + 5, y_pos + 5),
                       cv.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

        # Add title
        cv.putText(frame, "HEAT", (bar_x - 5, bar_y - 10),
                   cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Add current max value
        max_heat = heatmap.max()
        cv.putText(frame, f"Max: {max_heat:.2f}", (bar_x - 20, bar_y + bar_height + 20),
                   cv.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

        return frame

    def visualize_mask_enhancement(self, original_mask: np.ndarray,
                                   enhanced_mask: np.ndarray,
                                   debug_info: dict,
                                   window_prefix: str = "Enhancement"):
        """
        Visualize the mask enhancement process.

        Args:
            original_mask: Original binary mask
            enhanced_mask: Enhanced binary mask
            debug_info: Debug information from enhancement process
            window_prefix: Prefix for window names
        """

        # Convert boolean masks to uint8 for display
        if original_mask.dtype == bool:
            orig_display = (original_mask * 255).astype(np.uint8)
        else:
            orig_display = original_mask.astype(np.uint8)

        if enhanced_mask.dtype == bool:
            enhanced_display = (enhanced_mask * 255).astype(np.uint8)
        else:
            enhanced_display = enhanced_mask.astype(np.uint8)

        # Show progression
        cv.imshow(f"{window_prefix} - Original Mask", orig_display)
        cv.imshow(f"{window_prefix} - Dilated Mask", debug_info['dilated_mask'])
        cv.imshow(f"{window_prefix} - Enhanced Mask", enhanced_display)

        # Show difference
        difference = cv.absdiff(enhanced_display, orig_display)
        cv.imshow(f"{window_prefix} - Added Regions", difference)

        # print(f"Processed {debug_info['num_contours_processed']} contours")

    def infer(self, img, heatmap=None, alpha_start=0.002, alpha_end=0.0001,
              table_mask=None, cleaning_mask=None):
        """
        Inference with proper resizing to avoid spatial distortion:
        - Preserves original aspect ratios
        - Minimizes resize operations
        - Ensures spatial consistency between input and output
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Store original dimensions
        orig_H, orig_W = img.shape[:2]

        # Get the model's expected dimensions
        model_H, model_W = self.B.shape[:2]

        # Check if resizing is needed
        needs_resize = (orig_H, orig_W) != (model_H, model_W)

        if needs_resize:
            print(f"🔧 Resizing input from ({orig_H}, {orig_W}) to model size ({model_H}, {model_W})")

            # Use INTER_LINEAR for better quality; avoid INTER_NEAREST here
            img_resized = cv.resize(img, (model_W, model_H), interpolation=cv.INTER_LINEAR)
            img_tensor = torch.from_numpy(img_resized).float().to(device)

            # Process table mask with the appropriate interpolation
            if table_mask is not None:
                print(f"🔧 Resizing table mask from {table_mask.shape} to ({model_H}, {model_W})")
                # Use INTER_NEAREST for binary masks to preserve sharp edges
                table_mask_resized = cv.resize(table_mask.astype(np.uint8), (model_W, model_H),
                                               interpolation=cv.INTER_NEAREST)
                table_mask_tensor = torch.from_numpy(table_mask_resized > 0).bool().to(device)
            else:
                table_mask_tensor = torch.ones((model_H, model_W), dtype=torch.bool, device=device)

            # Resize existing heatmap if provided
            if heatmap is not None:
                if heatmap.shape != (model_H, model_W):
                    heatmap_resized = cv.resize(heatmap, (model_W, model_H), interpolation=cv.INTER_LINEAR)
                    heatmap = torch.from_numpy(heatmap_resized).float().to(device)
                else:
                    heatmap = torch.from_numpy(heatmap).float().to(device)
            else:
                heatmap = torch.zeros((model_H, model_W), dtype=torch.float32, device=device)

            working_H, working_W = model_H, model_W

        else:
            # No resizing needed
            img_tensor = torch.from_numpy(img).float().to(device)

            if table_mask is not None:
                table_mask_tensor = torch.from_numpy(table_mask > 0).bool().to(device)
            else:
                table_mask_tensor = torch.ones((orig_H, orig_W), dtype=torch.bool, device=device)

            if heatmap is not None:
                heatmap = torch.from_numpy(heatmap).float().to(device)
            else:
                heatmap = torch.zeros((orig_H, orig_W), dtype=torch.float32, device=device)

            working_H, working_W = orig_H, orig_W

        # Initialize foreground detection mask
        detection_mask = table_mask_tensor.clone()

        # GMM processing: mark pixels matching a background component
        for k in range(self.K):
            B_mask = (self.B >= (k + 1)).to(device)
            B_mask = B_mask & table_mask_tensor

            mu_k = self.mu[:, :, k, :].to(device)
            sigma_k = self.sigma[:, :, k, :, :].to(device)

            delta = img_tensor - mu_k
            delta = delta.unsqueeze(-1)
            sigma_inv = torch.linalg.inv(sigma_k)
            temp = torch.matmul(sigma_inv, delta)
            dist_sq = torch.matmul(delta.transpose(-2, -1), temp).squeeze(-1).squeeze(-1)
            dist = torch.sqrt(dist_sq + 1e-5)

            match_mask = (dist < 7.0) & B_mask
            detection_mask[match_mask] = False
            img_tensor[match_mask] = mu_k[match_mask]

        # Foreground detection
        foreground_mask = detection_mask & (img_tensor.abs().sum(dim=-1) > 0) & table_mask_tensor

        # === REGION PROPFILL ENHANCEMENT ===
        # (replaces the former `filled_mask = foreground_mask` pass-through)
        # Convert the foreground mask to numpy for processing
        foreground_np = foreground_mask.detach().cpu().numpy()
        table_mask_np = table_mask_tensor.detach().cpu().numpy()
        # Apply region propfill enhancement with hardcoded parameters
        enhanced_mask, debug_info = self.region_propfill_enhancement(
            foreground_np, table_mask=table_mask_np,
            dilation_kernel_size=3,   # hardcoded: size of dilation kernel
            dilation_iterations=1,    # hardcoded: connect nearby fragments
            erosion_iterations=2,     # hardcoded: restore original size
            fill_threshold=230,       # hardcoded: threshold for flood fill
            min_contour_area=200      # hardcoded: filter small noise
        )

        # Convert the enhanced mask back to a tensor
        filled_mask = torch.from_numpy(enhanced_mask).bool().to(device)

        # Optional: enhancement statistics
        if np.any(enhanced_mask != foreground_np):
            added_pixels = np.sum(enhanced_mask) - np.sum(foreground_np)
            # print(f"🔧 Region propfill added {added_pixels} pixels to fill hollow regions")

        # === ACCUMULATION: grow the heatmap slowly where foreground is detected ===
        pixelwise_alpha = alpha_start - (heatmap * (alpha_start - alpha_end))
        pixelwise_alpha = torch.clamp(pixelwise_alpha, min=alpha_end)

        heatmap = torch.where(
            filled_mask & table_mask_tensor,
            torch.clamp(heatmap + pixelwise_alpha * 0.3, 0, 1),  # 0.3 factor = slow growth
            heatmap
        )

        if cleaning_mask is not None:
            # Convert cleaning mask to tensor
            # (cleaning_mask is produced at the working resolution, so its
            # shape matches the heatmap here)
            cleaning_tensor = torch.from_numpy(cleaning_mask > 0).bool().to(device)

            # Calculate decay rate (slower for older/hotter areas)
            decay_alpha = alpha_start - (heatmap * (alpha_start - alpha_end))
            decay_alpha = torch.clamp(decay_alpha, min=alpha_end)

            # Apply gradual decay where cleaning
            heatmap = torch.where(
                cleaning_tensor & table_mask_tensor,
                torch.clamp(heatmap - decay_alpha * 0.8, 0, 1),  # 0.8 = decay slightly faster than growth
                heatmap
            )

        # === CRITICAL: proper output resizing ===
        heatmap_np = heatmap.detach().cpu().numpy()

        if needs_resize:
            # Resize results back to the original dimensions with
            # high-quality interpolation for the final output
            result_img = cv.resize(img_tensor.detach().cpu().numpy(), (orig_W, orig_H),
                                   interpolation=cv.INTER_LINEAR)

            # For the heatmap, INTER_LINEAR preserves smooth gradients
            heatmap_np = cv.resize(heatmap_np, (orig_W, orig_H), interpolation=cv.INTER_LINEAR)

            # Resize the table mask back for final masking
            if table_mask is not None:
                table_mask_final = cv.resize(table_mask_tensor.detach().cpu().numpy().astype(np.uint8),
                                             (orig_W, orig_H), interpolation=cv.INTER_NEAREST) > 0
                heatmap_np = heatmap_np * table_mask_final

            # Use the original image for blending
            result = img.copy()
        else:
            result_img = img_tensor.detach().cpu().numpy()
            result = img.copy()

            if table_mask is not None:
                table_mask_np = table_mask_tensor.detach().cpu().numpy()
                heatmap_np = heatmap_np * table_mask_np

        # === FIX: ensure the heatmap stays ONLY within table bounds ===
        if table_mask is not None:
            # Match dimensions
            if table_mask.shape != heatmap_np.shape:
                table_mask_resized = cv.resize(
                    table_mask.astype(np.uint8),
                    (heatmap_np.shape[1], heatmap_np.shape[0]),
                    interpolation=cv.INTER_NEAREST
                )
                table_mask_final = table_mask_resized > 0
            else:
                table_mask_final = table_mask > 0

            # CRITICAL: zero out the heatmap completely outside the table
            heatmap_np = heatmap_np * table_mask_final.astype(np.float32)
        else:
            table_mask_final = np.ones(heatmap_np.shape, dtype=bool)

        # Create visualization ONLY on the table area (no blue background)
        heatmap_colored = cv.applyColorMap(
            (heatmap_np * 255).astype(np.uint8),
            cv.COLORMAP_JET
        )

        # Apply transparency: only blend where heatmap > threshold AND inside table
        significant_heat = (heatmap_np > 0.1) & table_mask_final

        if np.any(significant_heat):
            # Blend ONLY significant areas
            result_blended = result.copy()
            result_blended[significant_heat] = cv.addWeighted(
                result[significant_heat], 0.7,
                heatmap_colored[significant_heat], 0.3, 0
            )
            result = result_blended

        return result, heatmap_np
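
# Minimal end-to-end usage sketch (not part of the module's API). The data
# directory, frame count, and output paths below are hypothetical; train()
# expects frames named b00000.png, b00001.png, ... inside data_dir.
if __name__ == '__main__':
    gmm = GMM(data_dir='data/frames', train_num=50)  # hypothetical path/value
    gmm.train(K=4)                                   # fit K Gaussians per pixel
    gmm.save_model('gmm_model.joblib')

    model = GMM.load_model('gmm_model.joblib')
    frame = cv.imread('data/frames/b00000.png')
    overlay, heatmap = model.infer(frame)            # heatmap accumulates over calls
    cv.imwrite('overlay.png', overlay)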
app_s_a_LiveCam.py ADDED
@@ -0,0 +1,1157 @@
import cv2
import torch
import numpy as np
from collections import deque
from threading import Thread, Lock
from queue import Queue, Empty, Full
import time
import logging
import os
from datetime import datetime, timedelta
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
import asyncio
import uvicorn
from pydantic import BaseModel
from typing import Optional
import requests

# ===== IMPORT THE DISCORD ALERT MANAGER =====
from send_discord import DiscordAlertManager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ==================== DATA MODELS ====================

class StreamStartRequest(BaseModel):
    """Start streaming request."""
    rtmp_input_url: str
    camera_path: str  # e.g., "models/cam1" - will auto-pick gmm_model.joblib and mask.png


class StreamStopRequest(BaseModel):
    """Stop streaming request."""
    stream_id: str


class StreamStatusResponse(BaseModel):
    """Stream status response."""
    stream_id: str
    status: str
    fps: float
    buffered_frames: int
    queue_size: int

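# Illustrative sketch of the request body POSTed to /stream/start (the URL and
# camera directory below are hypothetical examples, not values fixed anywhere
# in this file). Not called by the application.
def _example_start_request() -> StreamStartRequest:
    return StreamStartRequest(
        rtmp_input_url="rtmp://localhost/live/cam1",  # hypothetical RTMP source
        camera_path="models/cam1",  # must contain gmm_model.joblib and mask.png
    )
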

# ==================== CIRCULAR BUFFER ====================

class CircularFrameBuffer:
    """Fixed-size buffer for storing processed frames."""

    def __init__(self, max_frames: int = 30):
        self.max_frames = max_frames
        self.frames = deque(maxlen=max_frames)
        self.lock = Lock()
        self.sequence_ids = deque(maxlen=max_frames)

    def add_frame(self, frame: np.ndarray, seq_id: int) -> None:
        """Add processed frame to buffer."""
        with self.lock:
            self.frames.append(frame.copy())
            self.sequence_ids.append(seq_id)

    def get_latest(self) -> tuple:
        """Get most recent frame."""
        with self.lock:
            if len(self.frames) > 0:
                return self.frames[-1].copy(), self.sequence_ids[-1]
            return None, None

    def clear(self) -> None:
        """Clear buffer."""
        with self.lock:
            self.frames.clear()
            self.sequence_ids.clear()

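# Small illustrative sketch of the buffer semantics (not called anywhere):
# the deque's maxlen silently evicts the oldest frame once the buffer is
# full, so get_latest() always returns the most recent frame and its
# sequence id.
def _demo_circular_buffer():
    buf = CircularFrameBuffer(max_frames=2)
    buf.add_frame(np.zeros((4, 4, 3), dtype=np.uint8), seq_id=1)
    buf.add_frame(np.ones((4, 4, 3), dtype=np.uint8), seq_id=2)
    buf.add_frame(np.full((4, 4, 3), 2, dtype=np.uint8), seq_id=3)  # evicts seq 1
    frame, seq = buf.get_latest()
    assert seq == 3 and len(buf.frames) == 2
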

# ==================== LIVE MONITOR ====================

class LiveHygieneMonitor:
    """Production-ready hygiene monitor for live streams."""

    def __init__(self, segformer_path: str, max_buffer_frames: int = 30):
        self.segformer_path = segformer_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Model loading
        self.model = None
        self.processor = None
        self._load_segformer()

        # GMM components
        self.gmm_model = None
        self.gmm_heatmap = None
        self.table_mask = None

        # Live streaming state
        self.frame_buffer = CircularFrameBuffer(max_frames=max_buffer_frames)
        self.input_queue = Queue(maxsize=5)
        self.processing_thread = None
        self.is_running = False

        # Frame sequence tracking
        self.frame_sequence = 0
        self.frame_lock = Lock()

        # State management
        self.detection_frames_count = 0
        self.no_detection_frames_count = 0
        self.cleaning_active = False
        self.cleaning_start_threshold = 4
        self.cleaning_stop_threshold = 12

        # Performance tracking
        self.frame_times = deque(maxlen=30)
        self.last_frame_time = time.time()

        # Optimization flags
        self.skip_segformer_every_n_frames = 2
        self.segformer_skip_counter = 0
        self.last_cloth_mask = None

        # Visualization settings
        self.show_cloth_detection = True
        self.erasure_radius_factor = 0.2
        self.gaussian_sigma_factor = 0.8

        self.tracker = None
        self.track_trajectories = {}
        self.max_trajectory_length = 40
        self.track_colors = {}

        # Alert manager
        self.alert_manager = None
        self.current_camera_name = "Default Camera"

        logger.info(f"Live Monitor initialized on {self.device}")

    def _load_segformer(self):
        """Load SegFormer model."""
        try:
            self.model = SegformerForSemanticSegmentation.from_pretrained(self.segformer_path)
            self.processor = SegformerImageProcessor(do_reduce_labels=False)
            self.model.to(self.device)
            self.model.eval()
            logger.info(f"SegFormer loaded on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load SegFormer: {e}")

    def _init_tracker(self):
        """Lazy-init tracker."""
        if self.tracker is None:
            from deep_sort_realtime.deepsort_tracker import DeepSort
            self.tracker = DeepSort(
                max_age=15,
                n_init=2,
                nms_max_overlap=0.7,
                max_cosine_distance=0.4,
                nn_budget=50,
                embedder="mobilenet",
                half=True,
                embedder_gpu=torch.cuda.is_available()
            )

    def load_gmm_model(self, gmm_path: str) -> bool:
        """Load GMM model."""
        try:
            from GMM import GMM
            self.gmm_model = GMM.load_model(gmm_path)
            if self.gmm_model.img_shape:
                h, w = self.gmm_model.img_shape[:2]
                self.gmm_heatmap = np.zeros((h, w), dtype=np.float32)
            logger.info("GMM model loaded")
            return True
        except Exception as e:
            logger.error(f"Failed to load GMM: {e}")
            return False

    def load_table_mask(self, mask_path: str) -> bool:
        """Load table mask."""
        try:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise FileNotFoundError(f"Could not read mask: {mask_path}")
            self.table_mask = (mask > 128).astype(np.uint8)
            logger.info(f"Table mask loaded: {mask.shape}")
            return True
        except Exception as e:
            logger.error(f"Failed to load mask: {e}")
            return False

    def add_frame(self, frame: np.ndarray) -> None:
        """Add incoming frame (non-blocking)."""
        try:
            self.input_queue.put_nowait(frame)
        except Full:
            # Drop the frame if the queue is full; ingestion must not block
            pass

    def start_processing(self) -> None:
        """Start background processing."""
        if self.is_running:
            return
        self.is_running = True
        self.processing_thread = Thread(target=self._process_loop, daemon=True)
        self.processing_thread.start()
        logger.info("Processing thread started")

    def stop_processing(self) -> None:
        """Stop processing."""
        self.is_running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=5)
        self.frame_buffer.clear()
        logger.info("Processing stopped")

    def _get_next_sequence_id(self) -> int:
        """Thread-safe sequence ID."""
        with self.frame_lock:
            self.frame_sequence += 1
            return self.frame_sequence

    def _process_loop(self) -> None:
        """Main processing loop."""
        while self.is_running:
            try:
                frame = self.input_queue.get(timeout=1)
            except Empty:
                # No frame within the timeout; keep polling without logging an error
                continue

            try:
                seq_id = self._get_next_sequence_id()

                frame = self._resize_frame(frame, target_width=1024)
                cloth_mask = self._detect_cloth_fast(frame)
                cleaning_status = self._update_cleaning_status(cloth_mask)

                tracks = None
                if self.cleaning_active:
                    self._init_tracker()
                    tracks = self._track_cloth(frame, cloth_mask)

                self._update_gmm_fast(frame, cloth_mask, tracks)
                viz_frame = self._create_visualization(frame, cloth_mask, tracks, cleaning_status)
                self.frame_buffer.add_frame(viz_frame, seq_id)

                elapsed = time.time() - self.last_frame_time
                self.frame_times.append(elapsed)
                self.last_frame_time = time.time()

                if seq_id % 30 == 0:
                    avg_time = np.mean(self.frame_times)
                    fps = 1.0 / avg_time if avg_time > 0 else 0
                    logger.info(f"Seq {seq_id} | {fps:.1f} FPS | {cleaning_status}")

            except Exception as e:
                logger.error(f"Processing error: {e}")
                continue

    def _resize_frame(self, frame: np.ndarray, target_width: int = 1024) -> np.ndarray:
        """Resize frame, preserving aspect ratio."""
        h, w = frame.shape[:2]
        if w > target_width:
            scale = target_width / w
            new_h = int(h * scale)
            return cv2.resize(frame, (target_width, new_h))
        return frame

    def _detect_cloth_fast(self, frame: np.ndarray) -> np.ndarray:
        """Fast cloth detection with frame skipping."""
        if self.model is None:
            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

        self.segformer_skip_counter += 1
        if self.segformer_skip_counter < self.skip_segformer_every_n_frames:
            if self.last_cloth_mask is not None:
                return self.last_cloth_mask
            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

        self.segformer_skip_counter = 0

        try:
            height, width = frame.shape[:2]
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)

            with torch.no_grad():
                inputs = self.processor(images=pil_image, return_tensors="pt")
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs)
                logits = outputs.logits

                upsampled = torch.nn.functional.interpolate(
                    logits, size=(height, width), mode="bilinear", align_corners=False
                )

                cloth_mask = (upsampled.argmax(dim=1)[0].cpu().numpy() == 1).astype(np.uint8)

            if self.table_mask is not None:
                if self.table_mask.shape != cloth_mask.shape:
                    table_resized = cv2.resize(self.table_mask, (width, height))
                else:
                    table_resized = self.table_mask
                cloth_mask = cloth_mask * table_resized

            self.last_cloth_mask = cloth_mask
            return cloth_mask

        except Exception as e:
            logger.error(f"Cloth detection error: {e}")
            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

    def _track_cloth(self, frame: np.ndarray, cloth_mask: np.ndarray) -> list:
        """Fast tracking."""
        if self.tracker is None:
            return []

        try:
            contours, _ = cv2.findContours(cloth_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            detections = []

            for contour in contours:
                area = cv2.contourArea(contour)
                if area < 150:
                    continue
                x, y, w, h = cv2.boundingRect(contour)
                if w > 0 and h > 0:
                    detections.append(([x, y, w, h], 0.95, 'cloth'))

            if not detections:
                return []

            tracks = self.tracker.update_tracks(detections, frame=frame)

            height, width = frame.shape[:2]
            for track in tracks:
                if not track.is_confirmed():
                    continue

                track_id = track.track_id
                bbox = track.to_ltrb()
                cx = int((bbox[0] + bbox[2]) / 2)
                cy = int((bbox[1] + bbox[3]) / 2)

                if 0 <= cx < width and 0 <= cy < height:
                    if track_id not in self.track_trajectories:
                        self.track_trajectories[track_id] = deque(maxlen=self.max_trajectory_length)
                        self.track_colors[track_id] = (255, 255, 0)
                    self.track_trajectories[track_id].append((cx, cy))

            # Drop trajectories of tracks that are no longer confirmed
            active_ids = {track.track_id for track in tracks if track.is_confirmed()}
            dead_ids = set(self.track_trajectories.keys()) - active_ids
            for dead_id in dead_ids:
                self.track_trajectories.pop(dead_id, None)
                self.track_colors.pop(dead_id, None)

            return tracks

        except Exception as e:
            logger.error(f"Tracking error: {e}")
            return []

    def _update_gmm_fast(self, frame: np.ndarray, cloth_mask: np.ndarray, tracks: list) -> None:
        """Lightweight GMM update."""
        if self.gmm_model is None:
            return

        try:
            height, width = frame.shape[:2]
            table_mask = None
            if self.table_mask is not None:
                if self.table_mask.shape != (height, width):
                    table_mask = cv2.resize(self.table_mask, (width, height))
                else:
                    table_mask = self.table_mask

            _, self.gmm_heatmap = self.gmm_model.infer(
                frame, heatmap=self.gmm_heatmap,
                alpha_start=0.008, alpha_end=0.0004,
                table_mask=table_mask
            )

            if self.cleaning_active and tracks:
                for track in tracks:
                    if not track.is_confirmed():
                        continue

                    track_id = track.track_id
                    if track_id not in self.track_trajectories:
                        continue

                    trajectory = list(self.track_trajectories[track_id])
                    if len(trajectory) < 2:
                        continue

                    bbox = track.to_ltrb()
                    w = bbox[2] - bbox[0]
                    h = bbox[3] - bbox[1]

                    radius = int(min(w, h) * self.erasure_radius_factor)
                    radius = max(radius, 12)

                    if radius <= 0 or w <= 0 or h <= 0:
                        continue

                    for i in range(len(trajectory) - 1):
                        self._erase_at_point(trajectory[i], radius, table_mask)

        except Exception as e:
            logger.error(f"GMM update error: {e}")

    def _erase_at_point(self, point: tuple, radius: int, table_mask: np.ndarray) -> None:
        """Fast point-based erasure."""
        if self.gmm_heatmap is None or radius <= 0:
            return

        x, y = point
        height, width = self.gmm_heatmap.shape

        y_min = max(0, y - radius)
        y_max = min(height, y + radius)
        x_min = max(0, x - radius)
        x_max = min(width, x + radius)

        if y_min >= y_max or x_min >= x_max:
            return

        y_indices, x_indices = np.ogrid[y_min:y_max, x_min:x_max]
        distance_sq = (x_indices - x)**2 + (y_indices - y)**2

        # Gaussian falloff so erasure is strongest at the track point
        gaussian = np.exp(-distance_sq / (2 * (radius * self.gaussian_sigma_factor)**2))

        if table_mask is not None:
            gaussian = gaussian * table_mask[y_min:y_max, x_min:x_max]

        decay = 0.025 * gaussian
        self.gmm_heatmap[y_min:y_max, x_min:x_max] = np.maximum(
            0, self.gmm_heatmap[y_min:y_max, x_min:x_max] - decay
        )

    def _update_cleaning_status(self, cloth_mask: np.ndarray) -> str:
        """Update cleaning status with hysteresis."""
        has_cloth = np.sum(cloth_mask) > 100

        if has_cloth:
            self.detection_frames_count += 1
            self.no_detection_frames_count = 0
        else:
            self.no_detection_frames_count += 1
            self.detection_frames_count = 0

        if not self.cleaning_active and self.detection_frames_count >= self.cleaning_start_threshold:
            self.cleaning_active = True
            return "CLEANING STARTED"
        elif self.cleaning_active and self.no_detection_frames_count >= self.cleaning_stop_threshold:
            self.cleaning_active = False
            return "CLEANING STOPPED"

        return "CLEANING ACTIVE" if self.cleaning_active else "NO CLEANING"

    def _create_visualization(self, frame: np.ndarray, cloth_mask: np.ndarray,
                              tracks: list, cleaning_status: str) -> np.ndarray:
        """Create fast visualization."""
        result = frame.copy()

        if np.sum(cloth_mask) > 0:
            overlay = result.copy()
            cloth_pixels = cloth_mask > 0
            overlay[cloth_pixels] = [0, 255, 0]
            result[cloth_pixels] = cv2.addWeighted(
                frame[cloth_pixels], 0.7, overlay[cloth_pixels], 0.3, 0
            )

        if self.gmm_heatmap is not None and self.gmm_heatmap.max() > 0:
            height, width = result.shape[:2]
            heatmap_resized = cv2.resize(self.gmm_heatmap, (width, height))
            heatmap_colored = cv2.applyColorMap(
                (heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET
            )
            significant = heatmap_resized > 0.1
            result[significant] = cv2.addWeighted(
                frame[significant], 0.6, heatmap_colored[significant], 0.4, 0
            )

        if tracks:
            for track in tracks:
                if track.is_confirmed():
                    bbox = track.to_ltrb()
                    cx, cy = int((bbox[0] + bbox[2])/2), int((bbox[1] + bbox[3])/2)
                    cv2.circle(result, (cx, cy), 4, (0, 0, 255), -1)

        status_color = (0, 255, 0) if "ACTIVE" in cleaning_status else (150, 150, 150)
        cv2.putText(result, cleaning_status, (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)

        return result

    def get_latest_frame(self) -> np.ndarray:
        """Get latest processed frame."""
        frame, _ = self.frame_buffer.get_latest()
        return frame

    def get_stats(self) -> dict:
        """Get stats."""
        with self.frame_buffer.lock:
            avg_time = np.mean(self.frame_times) if len(self.frame_times) > 0 else 0.033
            fps = 1.0 / avg_time if avg_time > 0 else 0
            return {
                "buffered_frames": len(self.frame_buffer.frames),
                "avg_fps": fps,
                "queue_size": self.input_queue.qsize(),
                "is_running": self.is_running
            }

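# Illustrative sketch of the cleaning-state hysteresis implemented by
# _update_cleaning_status above: cleaning starts after 4 consecutive frames
# with cloth detected and stops only after 12 consecutive frames without it,
# so brief detection gaps do not toggle the state. Not called anywhere;
# __new__ is used only to skip model loading for the demonstration.
def _demo_cleaning_hysteresis():
    monitor = LiveHygieneMonitor.__new__(LiveHygieneMonitor)  # bypass __init__
    monitor.detection_frames_count = 0
    monitor.no_detection_frames_count = 0
    monitor.cleaning_active = False
    monitor.cleaning_start_threshold = 4
    monitor.cleaning_stop_threshold = 12

    cloth = np.ones((10, 10), dtype=np.uint8) * 255  # sum > 100 -> "cloth present"
    statuses = [monitor._update_cleaning_status(cloth) for _ in range(4)]
    assert statuses[-1] == "CLEANING STARTED"
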
512
+ # ==================== FASTAPI APP ====================
513
+
514
+ app = FastAPI(title="Hygiene Monitor Live Stream", version="1.0.0")
515
+
516
+ # Active streams: {stream_id: {"monitor": LiveHygieneMonitor, "cap": VideoCapture, "thread": Thread}}
517
+ active_streams = {}
518
+ streams_lock = Lock()
519
+
+
+ def _get_model_files(camera_path: str) -> tuple:
+     """Extract GMM and mask paths from camera directory."""
+     if not os.path.isdir(camera_path):
+         raise ValueError(f"Camera path not found: {camera_path}")
+
+     gmm_path = os.path.join(camera_path, "gmm_model.joblib")
+     mask_path = os.path.join(camera_path, "mask.png")
+
+     if not os.path.exists(gmm_path):
+         raise ValueError(f"GMM model not found: {gmm_path}")
+     if not os.path.exists(mask_path):
+         raise ValueError(f"Mask not found: {mask_path}")
+
+     return gmm_path, mask_path
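+
+ # Expected camera directory layout (as written by /camera/train_gmm below):
+ #
+ #   models/<camera_name>/
+ #       gmm_model.joblib    trained GMM background model
+ #       mask.png            binary table-area mask
+ #       thumb.png           annotated thumbnail
+ #       metadata.json       corner-point metadata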
+
+
+ def _stream_worker(stream_id: str, rtmp_url: str, gmm_path: str, mask_path: str):
+     """Background worker for streaming."""
+     try:
+         monitor = LiveHygieneMonitor(
+             segformer_path="models/segformer_model",
+             max_buffer_frames=30
+         )
+
+         if not monitor.load_gmm_model(gmm_path):
+             logger.error(f"[{stream_id}] Failed to load GMM model")
+             return
+
+         if not monitor.load_table_mask(mask_path):
+             logger.error(f"[{stream_id}] Failed to load mask")
+             return
+
+         # Initialize the Discord alert manager if a webhook is configured.
+         webhook_url = os.getenv("DISCORD_WEBHOOK_URL")  # From environment
+         if webhook_url:
+             # DiscordAlertManager expects a config dict (see send_discord.py).
+             monitor.alert_manager = DiscordAlertManager({"webhook_url": webhook_url})
+             monitor.current_camera_name = stream_id  # Or pass from request
+             logger.info(f"[{stream_id}] Alert manager initialized")
+
+         monitor.start_processing()
+
+         cap = cv2.VideoCapture(rtmp_url)
+         if not cap.isOpened():
+             logger.error(f"[{stream_id}] Failed to connect to RTMP: {rtmp_url}")
+             monitor.stop_processing()
+             return
+
+         # Update active stream
+         with streams_lock:
+             if stream_id in active_streams:
+                 active_streams[stream_id]["monitor"] = monitor
+                 active_streams[stream_id]["cap"] = cap
+                 active_streams[stream_id]["connected"] = True
+
+         frame_count = 0
+         logger.info(f"[{stream_id}] Connected to {rtmp_url}")
+
+         while True:
+             with streams_lock:
+                 if stream_id not in active_streams or not active_streams[stream_id]["running"]:
+                     break
+
+             ret, frame = cap.read()
+             if not ret:
+                 logger.warning(f"[{stream_id}] RTMP connection lost, reconnecting...")
+                 cap.release()
+                 time.sleep(2)
+                 cap = cv2.VideoCapture(rtmp_url)
+                 continue
+
+             monitor.add_frame(frame)
+             frame_count += 1
+
+             if frame_count % 100 == 0:
+                 stats = monitor.get_stats()
+                 logger.info(f"[{stream_id}] Frames: {frame_count}, FPS: {stats['avg_fps']:.1f}")
+
+     except Exception as e:
+         logger.error(f"[{stream_id}] Stream error: {e}")
+
+     finally:
+         with streams_lock:
+             if stream_id in active_streams:
+                 if active_streams[stream_id]["cap"]:
+                     active_streams[stream_id]["cap"].release()
+                 if active_streams[stream_id]["monitor"]:
+                     active_streams[stream_id]["monitor"].stop_processing()
+                 active_streams[stream_id]["connected"] = False
+
+         logger.info(f"[{stream_id}] Stream closed")
+
+
+ # ==================== ENDPOINTS ====================
+
+ @app.post("/stream/start")
+ async def start_stream(request: StreamStartRequest):
+     """Start a new live stream."""
+     stream_id = f"stream_{int(time.time() * 1000)}"
+
+     try:
+         # Extract model files from camera path
+         gmm_path, mask_path = _get_model_files(request.camera_path)
+
+         # Create stream entry
+         with streams_lock:
+             active_streams[stream_id] = {
+                 "running": True,
+                 "connected": False,
+                 "monitor": None,
+                 "cap": None,
+                 "thread": None,
+                 "camera_path": request.camera_path,
+                 "rtmp_url": request.rtmp_input_url  # kept so /stream/restart can respawn the worker
+             }
+
+         # Start background worker thread
+         thread = Thread(
+             target=_stream_worker,
+             args=(stream_id, request.rtmp_input_url, gmm_path, mask_path),
+             daemon=True
+         )
+         thread.start()
+
+         with streams_lock:
+             active_streams[stream_id]["thread"] = thread
+
+         logger.info(f"Stream {stream_id} started")
+         return {
+             "stream_id": stream_id,
+             "status": "starting",
+             "message": f"Stream {stream_id} is starting, will connect to {request.rtmp_input_url}"
+         }
+
+     except Exception as e:
+         logger.error(f"Failed to start stream: {e}")
+         raise HTTPException(status_code=400, detail=str(e))
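+
+ # Client sketch for /stream/start (illustrative; assumes the API listens on
+ # http://localhost:8000 and that StreamStartRequest has exactly the two
+ # fields used in the handler above):
+ #
+ #   import requests
+ #   resp = requests.post("http://localhost:8000/stream/start", json={
+ #       "rtmp_input_url": "rtmp://192.168.1.100:1935/live/kitchen",
+ #       "camera_path": "models/kitchen",
+ #   })
+ #   stream_id = resp.json()["stream_id"]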
+
+
+ @app.post("/stream/stop")
+ async def stop_stream(request: StreamStopRequest):
+     """Stop a live stream."""
+     stream_id = request.stream_id
+
+     with streams_lock:
+         if stream_id not in active_streams:
+             raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")
+
+         active_streams[stream_id]["running"] = False
+
+     logger.info(f"Stream {stream_id} stop requested")
+     return {"stream_id": stream_id, "status": "stopping"}
+
+
+ @app.get("/stream/status/{stream_id}")
+ async def get_stream_status(stream_id: str):
+     """Get stream status."""
+     with streams_lock:
+         if stream_id not in active_streams:
+             raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")
+
+         stream_data = active_streams[stream_id]
+         monitor = stream_data["monitor"]
+
+     stats = monitor.get_stats() if monitor else {}
+
+     return {
+         "stream_id": stream_id,
+         "connected": stream_data["connected"],
+         "running": stream_data["running"],
+         "camera_path": stream_data["camera_path"],
+         "fps": stats.get("avg_fps", 0),
+         "buffered_frames": stats.get("buffered_frames", 0),
+         "queue_size": stats.get("queue_size", 0)
+     }
+
+
+ @app.get("/stream/video/{stream_id}")
+ async def stream_video(stream_id: str):
+     """Stream video frames via MJPEG."""
+     with streams_lock:
+         if stream_id not in active_streams:
+             raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")
+
+         monitor = active_streams[stream_id]["monitor"]
+
+     if not monitor:
+         raise HTTPException(status_code=503, detail="Monitor not ready")
+
+     async def frame_generator():
+         while True:
+             with streams_lock:
+                 if stream_id not in active_streams or not active_streams[stream_id]["running"]:
+                     break
+
+             frame = monitor.get_latest_frame()
+             if frame is not None:
+                 _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
+                 yield (b'--frame\r\n'
+                        b'Content-Type: image/jpeg\r\n'
+                        b'Content-Length: ' + str(len(buffer)).encode() + b'\r\n\r\n'
+                        + buffer.tobytes() + b'\r\n')
+                 # Pace the generator (~30 FPS) so it does not spin and flood
+                 # the connection with duplicate frames.
+                 await asyncio.sleep(0.03)
+             else:
+                 await asyncio.sleep(0.01)
+
+     return StreamingResponse(
+         frame_generator(),
+         media_type="multipart/x-mixed-replace; boundary=frame"
+     )
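+
+ # Viewing sketch (illustrative): the MJPEG stream embeds directly in a
+ # browser via <img src="http://localhost:8000/stream/video/<stream_id>">,
+ # and OpenCV's FFmpeg backend can usually open it as well:
+ #
+ #   cap = cv2.VideoCapture(f"http://localhost:8000/stream/video/{stream_id}")
+ #   ret, frame = cap.read()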
+
+
+ @app.get("/streams")
+ async def list_streams():
+     """List all active streams."""
+     with streams_lock:
+         streams_list = []
+         for stream_id, data in active_streams.items():
+             monitor = data["monitor"]
+             stats = monitor.get_stats() if monitor else {}
+
+             streams_list.append({
+                 "stream_id": stream_id,
+                 "connected": data["connected"],
+                 "running": data["running"],
+                 "camera_path": data["camera_path"],
+                 "fps": stats.get("avg_fps", 0),
+                 "buffered_frames": stats.get("buffered_frames", 0)
+             })
+
+     return {"total_streams": len(streams_list), "streams": streams_list}
+
+
+ @app.post("/stream/restart/{stream_id}")
+ async def restart_stream(stream_id: str):
+     """Restart a stream by stopping its worker and spawning a new one."""
+     with streams_lock:
+         if stream_id not in active_streams:
+             raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")
+
+         active_streams[stream_id]["running"] = False
+
+     # Give the old worker time to notice the flag and release the capture.
+     await asyncio.sleep(2)
+
+     with streams_lock:
+         data = active_streams[stream_id]
+         data["running"] = True
+         camera_path = data["camera_path"]
+         rtmp_url = data["rtmp_url"]
+
+     # The old worker thread has already exited its loop, so flipping
+     # "running" back to True alone would not resume processing; a new
+     # worker must be spawned.
+     gmm_path, mask_path = _get_model_files(camera_path)
+     thread = Thread(
+         target=_stream_worker,
+         args=(stream_id, rtmp_url, gmm_path, mask_path),
+         daemon=True
+     )
+     thread.start()
+
+     with streams_lock:
+         active_streams[stream_id]["thread"] = thread
+
+     return {"stream_id": stream_id, "status": "restarting"}
+
+ @app.post("/camera/extract_frame")
+ async def extract_frame_from_rtmp(request: dict):
+     """
+     Extract a single frame from an RTMP stream for corner selection.
+
+     Request body:
+     {
+         "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
+         "camera_name": "kitchen"
+     }
+
+     Returns:
+     {
+         "success": true,
+         "frame_base64": "base64_encoded_image",
+         "frame_dimensions": {"width": 1920, "height": 1080}
+     }
+     """
+     try:
+         rtmp_url = request.get("rtmp_url")
+         camera_name = request.get("camera_name")
+
+         if not rtmp_url or not camera_name:
+             raise HTTPException(status_code=400, detail="Missing rtmp_url or camera_name")
+
+         # Connect to RTMP stream
+         cap = cv2.VideoCapture(rtmp_url)
+         if not cap.isOpened():
+             raise HTTPException(status_code=400, detail=f"Failed to connect to RTMP: {rtmp_url}")
+
+         # Read first frame
+         ret, frame = cap.read()
+         cap.release()
+
+         if not ret:
+             raise HTTPException(status_code=400, detail="Failed to read frame from RTMP stream")
+
+         # Convert frame to base64 for frontend display
+         import base64
+         _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
+         frame_base64 = base64.b64encode(buffer).decode('utf-8')
+
+         # Store frame temporarily for training (optional - could store in a memory cache)
+         temp_dir = "temp_frames"
+         os.makedirs(temp_dir, exist_ok=True)
+         temp_frame_path = os.path.join(temp_dir, f"{camera_name}_reference.jpg")
+         cv2.imwrite(temp_frame_path, frame)
+
+         return {
+             "success": True,
+             "frame_base64": frame_base64,
+             "frame_dimensions": {
+                 "width": frame.shape[1],
+                 "height": frame.shape[0]
+             },
+             "temp_frame_path": temp_frame_path
+         }
+
+     except HTTPException:
+         # Let the deliberate 400s above pass through instead of becoming 500s.
+         raise
+     except Exception as e:
+         logger.error(f"Extract frame error: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
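+
+ # Client sketch for /camera/extract_frame (illustrative; assumes the API at
+ # http://localhost:8000):
+ #
+ #   import base64, requests
+ #   resp = requests.post("http://localhost:8000/camera/extract_frame", json={
+ #       "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
+ #       "camera_name": "kitchen",
+ #   }).json()
+ #   with open("reference.jpg", "wb") as f:
+ #       f.write(base64.b64decode(resp["frame_base64"]))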
+
+
+ @app.post("/camera/train_gmm")
+ async def train_gmm_from_rtmp(request: dict):
+     """
+     Train a GMM model from an RTMP stream using N corner points (minimum 4).
+
+     Request body:
+     {
+         "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
+         "camera_name": "kitchen",
+         "corner_points": [
+             {"x": 100, "y": 50},
+             {"x": 400, "y": 45},
+             {"x": 700, "y": 55},
+             {"x": 800, "y": 60},
+             {"x": 850, "y": 300},
+             {"x": 850, "y": 600},
+             {"x": 400, "y": 620},
+             {"x": 50, "y": 580},
+             {"x": 45, "y": 300}
+         ],  // can be 4+ points for curved tables
+         "max_frames": 250,
+         "use_perspective_warp": false  // set false for non-rectangular tables
+     }
+     """
+     try:
+         rtmp_url = request.get("rtmp_url")
+         camera_name = request.get("camera_name")
+         corner_points = request.get("corner_points")
+         max_frames = request.get("max_frames", 250)
+         use_perspective_warp = request.get("use_perspective_warp", False)
+
+         # Validation
+         if not rtmp_url or not camera_name or not corner_points:
+             raise HTTPException(status_code=400, detail="Missing required parameters")
+
+         if len(corner_points) < 4:
+             raise HTTPException(status_code=400, detail="Minimum 4 corner points required")
+
+         logger.info(f"Starting GMM training for camera: {camera_name} with {len(corner_points)} points")
+
+         # ===== STEP 1: Connect to RTMP and capture frames =====
+         cap = cv2.VideoCapture(rtmp_url)
+         if not cap.isOpened():
+             raise HTTPException(status_code=400, detail=f"Failed to connect to RTMP: {rtmp_url}")
+
+         ret, first_frame = cap.read()
+         if not ret:
+             cap.release()
+             raise HTTPException(status_code=400, detail="Failed to read from RTMP stream")
+
+         h, w = first_frame.shape[:2]
+
+         # ===== STEP 2: Create polygon mask from N points =====
+         pts_polygon = np.array([
+             [point['x'], point['y']] for point in corner_points
+         ], dtype=np.int32)
+
+         # Create binary mask for the table area
+         table_mask = np.zeros((h, w), dtype=np.uint8)
+         cv2.fillPoly(table_mask, [pts_polygon], 255)
+
+         # ===== STEP 3: Decide transformation strategy =====
+         import tempfile
+         temp_dir = tempfile.mkdtemp()
+         frame_count = 0
+
+         if use_perspective_warp and len(corner_points) == 4:
+             # ===== STRATEGY A: Perspective warp (rectangular tables only) =====
+             logger.info("Using perspective warp for rectangular table")
+
+             # Corner order is assumed to be top-left, top-right, bottom-right,
+             # bottom-left, matching pts_dst below.
+             pts_src = np.array([
+                 [corner_points[0]['x'], corner_points[0]['y']],
+                 [corner_points[1]['x'], corner_points[1]['y']],
+                 [corner_points[2]['x'], corner_points[2]['y']],
+                 [corner_points[3]['x'], corner_points[3]['y']]
+             ], dtype=np.float32)
+
+             pts_dst = np.array([
+                 [0, 0], [w, 0], [w, h], [0, h]
+             ], dtype=np.float32)
+
+             matrix = cv2.getPerspectiveTransform(pts_src, pts_dst)
+
+             # Capture and warp frames (setting CAP_PROP_POS_FRAMES is
+             # typically a no-op on live RTMP sources; it rewinds file URLs)
+             cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+             while frame_count < max_frames:
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+
+                 warped = cv2.warpPerspective(frame, matrix, (w, h))
+                 frame_path = os.path.join(temp_dir, f'b{frame_count:05d}.png')
+                 cv2.imwrite(frame_path, warped)
+                 frame_count += 1
+
+                 if frame_count % 50 == 0:
+                     logger.info(f"Captured {frame_count}/{max_frames} frames")
+
+             # For warped images, mask should be full frame (already aligned)
+             final_mask = np.ones((h, w), dtype=np.uint8) * 255
+
+         else:
+             # ===== STRATEGY B: Direct masking (curved/complex tables) =====
+             logger.info(f"Using direct masking for {len(corner_points)}-point polygon (curved table)")
+
+             # Capture frames WITHOUT warping, apply mask during inference
+             cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+             while frame_count < max_frames:
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+
+                 # Apply mask to frame (zero out outside table area)
+                 masked_frame = cv2.bitwise_and(frame, frame, mask=table_mask)
+
+                 frame_path = os.path.join(temp_dir, f'b{frame_count:05d}.png')
+                 cv2.imwrite(frame_path, masked_frame)
+                 frame_count += 1
+
+                 if frame_count % 50 == 0:
+                     logger.info(f"Captured {frame_count}/{max_frames} frames")
+
+             # Use original polygon mask
+             final_mask = table_mask
+
+         cap.release()
+
+         if frame_count == 0:
+             raise HTTPException(status_code=400, detail="No frames captured")
+
+         logger.info(f"Captured {frame_count} frames, starting GMM training...")
+
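+         # Note on the two strategies: Strategy A rectifies the table so the
+         # GMM trains on an axis-aligned, full-frame surface (appropriate when
+         # the table really is a quadrilateral), while Strategy B keeps the
+         # original geometry and relies on the polygon mask, which also
+         # handles curved outlines at the cost of training on the blacked-out
+         # pixels outside the mask.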
+         # ===== STEP 4: Train GMM =====
+         from GMM import GMM
+         gmm = GMM(temp_dir, frame_count, alpha=0.05)
+         gmm.train(K=4)
+         logger.info("GMM training complete")
+
+         # ===== STEP 5: Save artifacts =====
+         camera_path = os.path.join("models", camera_name)
+         os.makedirs(camera_path, exist_ok=True)
+
+         # 1. Save GMM model
+         gmm_path = os.path.join(camera_path, "gmm_model.joblib")
+         gmm.save_model(gmm_path)
+
+         # 2. Save mask (polygon-based, not rectangular)
+         mask_path = os.path.join(camera_path, "mask.png")
+         cv2.imwrite(mask_path, final_mask)
+         logger.info(f"Saved {len(corner_points)}-point polygon mask to {mask_path}")
+
+         # 3. Create thumbnail with polygon overlay
+         thumb_frame = first_frame.copy()
+
+         # Draw filled polygon with transparency
+         overlay = thumb_frame.copy()
+         cv2.fillPoly(overlay, [pts_polygon], (0, 255, 0))
+         cv2.addWeighted(thumb_frame, 0.7, overlay, 0.3, 0, thumb_frame)
+
+         # Draw polygon border
+         cv2.polylines(thumb_frame, [pts_polygon], True, (0, 255, 0), 3)
+
+         # Draw corner points with numbers
+         colors = [
+             (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
+             (255, 0, 255), (0, 255, 255), (128, 0, 128), (255, 128, 0)
+         ]
+
+         for i, point in enumerate(corner_points):
+             x, y = point['x'], point['y']
+             color = colors[i % len(colors)]
+
+             cv2.circle(thumb_frame, (x, y), 8, color, -1)
+             cv2.circle(thumb_frame, (x, y), 10, (255, 255, 255), 2)
+
+             # Point number
+             cv2.putText(thumb_frame, str(i+1), (x+15, y),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
+
+         # Camera name label
+         cv2.putText(thumb_frame, camera_name, (30, 50),
+                     cv2.FONT_HERSHEY_DUPLEX, 1.5, (255, 255, 255), 3)
+         cv2.putText(thumb_frame, camera_name, (30, 50),
+                     cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 255, 0), 2)
+
+         # Add point count indicator
+         cv2.putText(thumb_frame, f"{len(corner_points)} points", (30, 90),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
+
+         thumb_path = os.path.join(camera_path, "thumb.png")
+         cv2.imwrite(thumb_path, thumb_frame)
+
+         # 4. Save polygon metadata (for later reconstruction of the setup)
+         metadata = {
+             "camera_name": camera_name,
+             "num_points": len(corner_points),
+             "corner_points": corner_points,
+             "frame_dimensions": {"width": w, "height": h},
+             "use_perspective_warp": use_perspective_warp,
+             "training_date": datetime.now().isoformat()
+         }
+
+         import json
+         metadata_path = os.path.join(camera_path, "metadata.json")
+         with open(metadata_path, 'w') as f:
+             json.dump(metadata, f, indent=2)
+
+         logger.info(f"Saved metadata to {metadata_path}")
+
+         # Cleanup
+         import shutil
+         shutil.rmtree(temp_dir)
+
+         logger.info(f"✅ Camera '{camera_name}' training complete with {len(corner_points)}-point polygon!")
+
+         return {
+             "success": True,
+             "camera_name": camera_name,
+             "camera_path": camera_path,
+             "frames_captured": frame_count,
+             "polygon_points": len(corner_points),
+             "use_perspective_warp": use_perspective_warp,
+             "model_files": {
+                 "gmm_model": gmm_path,
+                 "mask": mask_path,
+                 "thumbnail": thumb_path,
+                 "metadata": metadata_path
+             }
+         }
+
+     except HTTPException:
+         # Let the deliberate 400s above pass through instead of becoming 500s.
+         raise
+     except Exception as e:
+         logger.error(f"GMM training error: {e}")
+         import traceback
+         logger.error(traceback.format_exc())
+         raise HTTPException(status_code=500, detail=str(e))
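+
+ # Client sketch for /camera/train_gmm (illustrative; assumes the API at
+ # http://localhost:8000 and a roughly rectangular table given as four
+ # corners in top-left, top-right, bottom-right, bottom-left order):
+ #
+ #   import requests
+ #   resp = requests.post("http://localhost:8000/camera/train_gmm", json={
+ #       "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
+ #       "camera_name": "kitchen",
+ #       "corner_points": [{"x": 100, "y": 50}, {"x": 800, "y": 60},
+ #                         {"x": 850, "y": 600}, {"x": 50, "y": 580}],
+ #       "max_frames": 250,
+ #       "use_perspective_warp": True,
+ #   })
+ #   print(resp.json()["model_files"])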
+
+
+ @app.get("/cameras")
+ async def list_cameras():
+     """
+     List all trained cameras with their metadata.
+
+     Returns:
+     {
+         "cameras": [
+             {
+                 "name": "kitchen",
+                 "path": "models/kitchen",
+                 "thumbnail": "models/kitchen/thumb.png",
+                 "has_gmm_model": true,
+                 "has_mask": true
+             }
+         ]
+     }
+     """
+     try:
+         cameras = []
+         models_dir = "models"
+
+         if not os.path.exists(models_dir):
+             return {"cameras": []}
+
+         for camera_name in os.listdir(models_dir):
+             camera_path = os.path.join(models_dir, camera_name)
+
+             if not os.path.isdir(camera_path):
+                 continue
+
+             gmm_path = os.path.join(camera_path, "gmm_model.joblib")
+             mask_path = os.path.join(camera_path, "mask.png")
+             thumb_path = os.path.join(camera_path, "thumb.png")
+
+             cameras.append({
+                 "name": camera_name,
+                 "path": camera_path,
+                 "thumbnail": thumb_path if os.path.exists(thumb_path) else None,
+                 "has_gmm_model": os.path.exists(gmm_path),
+                 "has_mask": os.path.exists(mask_path)
+             })
+
+         return {"cameras": cameras}
+
+     except Exception as e:
+         logger.error(f"List cameras error: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.delete("/camera/{camera_name}")
+ async def delete_camera(camera_name: str):
+     """
+     Delete a trained camera and all its files.
+     """
+     try:
+         camera_path = os.path.join("models", camera_name)
+
+         if not os.path.exists(camera_path):
+             raise HTTPException(status_code=404, detail=f"Camera '{camera_name}' not found")
+
+         import shutil
+         shutil.rmtree(camera_path)
+
+         logger.info(f"Deleted camera: {camera_name}")
+
+         return {
+             "success": True,
+             "message": f"Camera '{camera_name}' deleted successfully"
+         }
+
+     except HTTPException:
+         # Preserve the 404 above instead of converting it into a 500.
+         raise
+     except Exception as e:
+         logger.error(f"Delete camera error: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint."""
+     with streams_lock:
+         stream_count = len(active_streams)
+
+     return {
+         "status": "healthy",
+         "active_streams": stream_count,
+         "timestamp": datetime.now().isoformat()
+     }
+
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt ADDED
@@ -0,0 +1,29 @@
+ opencv-python
+ opencv-contrib-python
+ joblib
+ scikit-learn
+ numpy==1.24.3
+ #torchvision==0.15.2
+ ultralytics
+ gradio
+ Pillow
+ matplotlib==3.7.2
+ python-dateutil==2.8.2
+
+ # Additional dependencies
+ fastapi          # API server (app_s_a_LiveCam.py)
+ uvicorn          # ASGI runner for the FastAPI app
+ pyyaml>=6.0
+ requests>=2.31.0
+ scipy>=1.11.0
+ pandas>=2.0.3
+ tqdm>=4.65.0
+ seaborn>=0.12.2
+
+ # For better video codec support
+ imageio
+ imageio-ffmpeg
+
+ # System utilities
+ psutil>=5.9.0
+ plotly
+ torch
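+
+ # Install with: pip install -r requirements.txt  (Python 3.8+ assumed; the
+ # stdlib already provides pathlib, so the obsolete PyPI backport is omitted)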
send_discord.py ADDED
@@ -0,0 +1,172 @@
+ import requests
+ import cv2
+ import numpy as np
+ from datetime import datetime
+ from pathlib import Path
+ import logging
+ import json
+
+ logger = logging.getLogger(__name__)
+
+
+ class DiscordAlertManager:
+     """Manages Discord webhook alerts for hygiene violations."""
+
+     def __init__(self, discord_config: dict):
+         """
+         discord_config: {
+             'webhook_url': 'your_webhook_url'
+         }
+         """
+         self.webhook_url = discord_config['webhook_url']
+         self.alert_cooldown = 300              # seconds between alerts
+         self.last_alert_time = None
+         self.dirty_start_time = None
+         self.dirty_threshold_seconds = 10      # dirt must persist this long
+         self.dirty_coverage_threshold = 0.06   # fraction of dirty pixels
+
+     def should_send_alert(self, dirty_coverage: float, current_time: datetime) -> bool:
+         """Return True once dirt has persisted past the threshold and the cooldown has elapsed."""
+         if dirty_coverage < self.dirty_coverage_threshold:
+             self.dirty_start_time = None
+             return False
+
+         if self.dirty_start_time is None:
+             self.dirty_start_time = current_time
+             return False
+
+         dirty_duration = (current_time - self.dirty_start_time).total_seconds()
+         if dirty_duration < self.dirty_threshold_seconds:
+             return False
+
+         if self.last_alert_time is not None:
+             time_since_last = (current_time - self.last_alert_time).total_seconds()
+             if time_since_last < self.alert_cooldown:
+                 return False
+
+         return True
+
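+     # Timeline sketch of the gating above, using the defaults from __init__
+     # (coverage threshold 0.06, persistence 10 s, cooldown 300 s) and
+     # assuming the first alert is actually delivered:
+     #
+     #   t=0 s    coverage rises above 6%  -> dirty_start_time set, no alert
+     #   t=10 s   still dirty              -> alert allowed (and sent)
+     #   t=60 s   still dirty              -> suppressed by the 300 s cooldown
+     #   t=310 s  still dirty              -> next alert allowed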
+     def generate_heatmap_image(self, frame: np.ndarray, gmm_heatmap: np.ndarray,
+                                output_path: str) -> str:
+         """Render the GMM dirt heatmap over the frame and save it to disk."""
+         result = frame.copy()
+         height, width = result.shape[:2]
+
+         if gmm_heatmap.shape != (height, width):
+             heatmap_resized = cv2.resize(gmm_heatmap, (width, height))
+         else:
+             heatmap_resized = gmm_heatmap
+
+         heatmap_colored = cv2.applyColorMap(
+             (heatmap_resized * 255).astype(np.uint8),
+             cv2.COLORMAP_JET
+         )
+
+         alpha = 0.5
+         result = cv2.addWeighted(frame, 1 - alpha, heatmap_colored, alpha, 0)
+
+         # Add info panel
+         avg_dirt = np.mean(heatmap_resized)
+         max_dirt = np.max(heatmap_resized)
+         dirty_pixels = np.sum(heatmap_resized > 0.60)
+         coverage_percent = (dirty_pixels / heatmap_resized.size) * 100
+
+         cv2.rectangle(result, (10, 10), (400, 120), (0, 0, 0), -1)
+         cv2.rectangle(result, (10, 10), (400, 120), (255, 255, 255), 2)
+
+         info_text = [
+             f"Average Dirt: {avg_dirt:.2f}",
+             f"Maximum Dirt: {max_dirt:.2f}",
+             f"Red Zone: {coverage_percent:.1f}%",
+             f"Time: {datetime.now().strftime('%H:%M:%S')}"
+         ]
+
+         for i, text in enumerate(info_text):
+             cv2.putText(result, text, (20, 35 + i * 25),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
+
+         # Ensure the output directory exists; cv2.imwrite fails silently otherwise.
+         Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+         cv2.imwrite(output_path, result)
+         return output_path
+
+     def send_alert(self, camera_name: str, dirty_coverage: float,
+                    dirty_duration: int, frame: np.ndarray,
+                    gmm_heatmap: np.ndarray) -> bool:
+         """Send a Discord webhook alert with the heatmap image attached."""
+         try:
+             # Generate image
+             temp_image_path = f"tmp/heatmap_{datetime.now().timestamp()}.png"
+             self.generate_heatmap_image(frame, gmm_heatmap, temp_image_path)
+
+             # Calculate duration
+             duration_mins = dirty_duration // 60
+             duration_secs = dirty_duration % 60
+
+             # Create rich embed
+             embed = {
+                 "title": "🚨 CLEANING ALERT",
+                 "description": f"**{camera_name}** requires immediate attention!",
+                 "color": 15158332,  # red (#E74C3C)
+                 "fields": [
+                     {
+                         "name": "📍 Location",
+                         "value": camera_name,
+                         "inline": True
+                     },
+                     {
+                         "name": "🔴 Coverage",
+                         "value": f"{dirty_coverage*100:.1f}%",
+                         "inline": True
+                     },
+                     {
+                         "name": "⏱ Duration",
+                         "value": f"{duration_mins}m {duration_secs}s",
+                         "inline": True
+                     },
+                     {
+                         "name": "⚠️ Action Required",
+                         "value": "Table has exceeded cleanliness threshold and needs cleaning.",
+                         "inline": False
+                     }
+                 ],
+                 "footer": {
+                     "text": "Kitchen Hygiene Monitoring System"
+                 },
+                 "timestamp": datetime.utcnow().isoformat()
+             }
+
+             # Prepare webhook payload with embeds
+             payload = {
+                 "username": "Hygiene Monitor Bot",
+                 "avatar_url": "https://cdn-icons-png.flaticon.com/512/3699/3699516.png",
+                 "embeds": [embed]
+             }
+
+             # Read the image file
+             with open(temp_image_path, 'rb') as f:
+                 image_data = f.read()
+
+             # Prepare the multipart form data
+             files = {
+                 'payload_json': (None, json.dumps(payload), 'application/json'),
+                 'file': ('heatmap.png', image_data, 'image/png')
+             }
+
+             # Send the request (bounded so a stalled webhook cannot hang the caller)
+             response = requests.post(self.webhook_url, files=files, timeout=15)
+
+             if response.status_code in [200, 204]:
+                 self.last_alert_time = datetime.now()
+                 logger.info(f"✅ Discord alert sent for {camera_name}")
+                 Path(temp_image_path).unlink(missing_ok=True)
+                 return True
+             else:
+                 logger.error(f"Discord webhook error: {response.status_code} - {response.text}")
+                 Path(temp_image_path).unlink(missing_ok=True)
+                 return False
+
+         except Exception as e:
+             logger.error(f"Failed to send Discord alert: {str(e)}")
+             import traceback
+             logger.error(traceback.format_exc())
+             return False
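+
+
+ # Usage sketch (illustrative; the webhook URL is a placeholder, and `frame`
+ # / `gmm_heatmap` stand in for a BGR frame and a [0, 1] heatmap array):
+ #
+ #   manager = DiscordAlertManager({"webhook_url": "https://discord.com/api/webhooks/..."})
+ #   if manager.should_send_alert(dirty_coverage=0.12, current_time=datetime.now()):
+ #       manager.send_alert("kitchen", 0.12, 45, frame, gmm_heatmap)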