mlbench123 committed · Commit 0d571fe · verified · 1 Parent(s): 8efc50e

Update app.py

Files changed (1)
  1. app.py +1156 -1156
app.py CHANGED
@@ -1,1157 +1,1157 @@
Every line was rewritten in place; the only substantive change is dropping the stray StreamingResponse import from fastapi, which is already imported from fastapi.responses:

- from fastapi import FastAPI, HTTPException, StreamingResponse
+ from fastapi import FastAPI, HTTPException

The updated app.py in full:
import cv2
import torch
import numpy as np
from collections import deque
from threading import Thread, Lock
from queue import Queue, Empty, Full
import time
import logging
import os
from datetime import datetime, timedelta
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
import asyncio
import uvicorn
from pydantic import BaseModel
from typing import Optional
import requests

# ===== IMPORT THE DISCORD ALERT MANAGER =====
from send_discord import DiscordAlertManager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ==================== DATA MODELS ====================


class StreamStartRequest(BaseModel):
    """Start streaming request."""
    rtmp_input_url: str
    camera_path: str  # e.g., "models/cam1" - will auto-pick gmm_model.joblib and mask.png


class StreamStopRequest(BaseModel):
    """Stop streaming request."""
    stream_id: str


class StreamStatusResponse(BaseModel):
    """Stream status response."""
    stream_id: str
    status: str
    fps: float
    buffered_frames: int
    queue_size: int

# ==================== CIRCULAR BUFFER ====================


class CircularFrameBuffer:
    """Fixed-size buffer for storing processed frames."""

    def __init__(self, max_frames: int = 30):
        self.max_frames = max_frames
        self.frames = deque(maxlen=max_frames)
        self.lock = Lock()
        self.sequence_ids = deque(maxlen=max_frames)

    def add_frame(self, frame: np.ndarray, seq_id: int) -> None:
        """Add processed frame to buffer."""
        with self.lock:
            self.frames.append(frame.copy())
            self.sequence_ids.append(seq_id)

    def get_latest(self) -> tuple:
        """Get most recent frame."""
        with self.lock:
            if len(self.frames) > 0:
                return self.frames[-1].copy(), self.sequence_ids[-1]
            return None, None

    def clear(self) -> None:
        """Clear buffer."""
        with self.lock:
            self.frames.clear()
            self.sequence_ids.clear()

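The buffer is written by the background processing thread and read by the HTTP layer. A quick standalone sketch of the drop-oldest behaviour (synthetic frames; assumes it runs alongside the definitions above):

    buf = CircularFrameBuffer(max_frames=5)
    for i in range(10):
        buf.add_frame(np.zeros((4, 4, 3), dtype=np.uint8), seq_id=i)
    frame, seq = buf.get_latest()
    # seq == 9; the deque silently dropped frames 0-4, so readers always see the newest frame
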
# ==================== LIVE MONITOR ====================


class LiveHygieneMonitor:
    """Production-ready hygiene monitor for live streams."""

    def __init__(self, segformer_path: str, max_buffer_frames: int = 30):
        self.segformer_path = segformer_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Model loading
        self.model = None
        self.processor = None
        self._load_segformer()

        # GMM components
        self.gmm_model = None
        self.gmm_heatmap = None
        self.table_mask = None

        # Live streaming state
        self.frame_buffer = CircularFrameBuffer(max_frames=max_buffer_frames)
        self.input_queue = Queue(maxsize=5)
        self.processing_thread = None
        self.is_running = False

        # Frame sequence tracking
        self.frame_sequence = 0
        self.frame_lock = Lock()

        # State management
        self.detection_frames_count = 0
        self.no_detection_frames_count = 0
        self.cleaning_active = False
        self.cleaning_start_threshold = 4
        self.cleaning_stop_threshold = 12

        # Performance tracking
        self.frame_times = deque(maxlen=30)
        self.last_frame_time = time.time()

        # Optimization flags
        self.skip_segformer_every_n_frames = 2
        self.segformer_skip_counter = 0
        self.last_cloth_mask = None

        # Visualization settings
        self.show_cloth_detection = True
        self.erasure_radius_factor = 0.2
        self.gaussian_sigma_factor = 0.8

        self.tracker = None
        self.track_trajectories = {}
        self.max_trajectory_length = 40
        self.track_colors = {}

        # Alert manager (optional Discord notifications)
        self.alert_manager = None
        self.current_camera_name = "Default Camera"

        logger.info(f"Live Monitor initialized on {self.device}")

    def _load_segformer(self):
        """Load SegFormer model."""
        try:
            self.model = SegformerForSemanticSegmentation.from_pretrained(self.segformer_path)
            self.processor = SegformerImageProcessor(do_reduce_labels=False)
            self.model.to(self.device)
            self.model.eval()
            logger.info(f"SegFormer loaded on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load SegFormer: {e}")

    def _init_tracker(self):
        """Lazy-init tracker."""
        if self.tracker is None:
            from deep_sort_realtime.deepsort_tracker import DeepSort
            self.tracker = DeepSort(
                max_age=15,
                n_init=2,
                nms_max_overlap=0.7,
                max_cosine_distance=0.4,
                nn_budget=50,
                embedder="mobilenet",
                half=True,
                embedder_gpu=torch.cuda.is_available()
            )

    def load_gmm_model(self, gmm_path: str) -> bool:
        """Load GMM model."""
        try:
            from GMM import GMM
            self.gmm_model = GMM.load_model(gmm_path)
            if self.gmm_model.img_shape:
                h, w = self.gmm_model.img_shape[:2]
                self.gmm_heatmap = np.zeros((h, w), dtype=np.float32)
            logger.info("GMM model loaded")
            return True
        except Exception as e:
            logger.error(f"Failed to load GMM: {e}")
            return False

    def load_table_mask(self, mask_path: str) -> bool:
        """Load table mask."""
        try:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                # cv2.imread returns None instead of raising on a missing/corrupt file
                raise ValueError(f"Could not read mask image: {mask_path}")
            self.table_mask = (mask > 128).astype(np.uint8)
            logger.info(f"Table mask loaded: {mask.shape}")
            return True
        except Exception as e:
            logger.error(f"Failed to load mask: {e}")
            return False

    def add_frame(self, frame: np.ndarray) -> None:
        """Add incoming frame (non-blocking)."""
        try:
            self.input_queue.put_nowait(frame)
        except Full:
            # Drop the frame when the processing queue is saturated
            pass

    def start_processing(self) -> None:
        """Start background processing."""
        if self.is_running:
            return
        self.is_running = True
        self.processing_thread = Thread(target=self._process_loop, daemon=True)
        self.processing_thread.start()
        logger.info("Processing thread started")

    def stop_processing(self) -> None:
        """Stop processing."""
        self.is_running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=5)
        self.frame_buffer.clear()
        logger.info("Processing stopped")

    def _get_next_sequence_id(self) -> int:
        """Thread-safe sequence ID."""
        with self.frame_lock:
            self.frame_sequence += 1
            return self.frame_sequence

    def _process_loop(self) -> None:
        """Main processing loop."""
        while self.is_running:
            try:
                frame = self.input_queue.get(timeout=1)
            except Empty:
                # No frame arrived within the timeout; keep polling quietly
                continue

            try:
                seq_id = self._get_next_sequence_id()

                frame = self._resize_frame(frame, target_width=1024)
                cloth_mask = self._detect_cloth_fast(frame)
                cleaning_status = self._update_cleaning_status(cloth_mask)

                tracks = None
                if self.cleaning_active:
                    self._init_tracker()
                    tracks = self._track_cloth(frame, cloth_mask)

                self._update_gmm_fast(frame, cloth_mask, tracks)
                viz_frame = self._create_visualization(frame, cloth_mask, tracks, cleaning_status)
                self.frame_buffer.add_frame(viz_frame, seq_id)

                elapsed = time.time() - self.last_frame_time
                self.frame_times.append(elapsed)
                self.last_frame_time = time.time()

                if seq_id % 30 == 0:
                    avg_time = np.mean(self.frame_times)
                    fps = 1.0 / avg_time if avg_time > 0 else 0
                    logger.info(f"Seq {seq_id} | {fps:.1f} FPS | {cleaning_status}")

            except Exception as e:
                logger.error(f"Processing error: {e}")
                continue

    def _resize_frame(self, frame: np.ndarray, target_width: int = 1024) -> np.ndarray:
        """Downscale frames wider than target_width, preserving aspect ratio."""
        h, w = frame.shape[:2]
        if w > target_width:
            scale = target_width / w
            new_h = int(h * scale)
            return cv2.resize(frame, (target_width, new_h))
        return frame

    def _detect_cloth_fast(self, frame: np.ndarray) -> np.ndarray:
        """Fast cloth detection: run SegFormer only every Nth frame and reuse the last mask in between."""
        if self.model is None:
            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

        self.segformer_skip_counter += 1
        if self.segformer_skip_counter < self.skip_segformer_every_n_frames:
            if self.last_cloth_mask is not None:
                return self.last_cloth_mask
            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

        self.segformer_skip_counter = 0

        try:
            height, width = frame.shape[:2]
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)

            with torch.no_grad():
                inputs = self.processor(images=pil_image, return_tensors="pt")
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs)
                logits = outputs.logits

            upsampled = torch.nn.functional.interpolate(
                logits, size=(height, width), mode="bilinear", align_corners=False
            )

            cloth_mask = (upsampled.argmax(dim=1)[0].cpu().numpy() == 1).astype(np.uint8)

            if self.table_mask is not None:
                if self.table_mask.shape != cloth_mask.shape:
                    table_resized = cv2.resize(self.table_mask, (width, height))
                else:
                    table_resized = self.table_mask
                cloth_mask = cloth_mask * table_resized

            self.last_cloth_mask = cloth_mask
            return cloth_mask

        except Exception as e:
            logger.error(f"Cloth detection error: {e}")
            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

    def _track_cloth(self, frame: np.ndarray, cloth_mask: np.ndarray) -> list:
        """Fast tracking."""
        if self.tracker is None:
            return []

        try:
            contours, _ = cv2.findContours(cloth_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            detections = []

            for contour in contours:
                area = cv2.contourArea(contour)
                if area < 150:
                    continue
                x, y, w, h = cv2.boundingRect(contour)
                if w > 0 and h > 0:
                    detections.append(([x, y, w, h], 0.95, 'cloth'))

            if not detections:
                return []

            tracks = self.tracker.update_tracks(detections, frame=frame)

            height, width = frame.shape[:2]
            for track in tracks:
                if not track.is_confirmed():
                    continue

                track_id = track.track_id
                bbox = track.to_ltrb()
                cx = int((bbox[0] + bbox[2]) / 2)
                cy = int((bbox[1] + bbox[3]) / 2)

                if 0 <= cx < width and 0 <= cy < height:
                    if track_id not in self.track_trajectories:
                        self.track_trajectories[track_id] = deque(maxlen=self.max_trajectory_length)
                        self.track_colors[track_id] = (255, 255, 0)
                    self.track_trajectories[track_id].append((cx, cy))

            active_ids = {track.track_id for track in tracks if track.is_confirmed()}
            dead_ids = set(self.track_trajectories.keys()) - active_ids
            for dead_id in dead_ids:
                self.track_trajectories.pop(dead_id, None)
                self.track_colors.pop(dead_id, None)

            return tracks

        except Exception as e:
            logger.error(f"Tracking error: {e}")
            return []

    def _update_gmm_fast(self, frame: np.ndarray, cloth_mask: np.ndarray, tracks: list) -> None:
        """Lightweight GMM update."""
        if self.gmm_model is None:
            return

        try:
            height, width = frame.shape[:2]
            table_mask = None
            if self.table_mask is not None:
                if self.table_mask.shape != (height, width):
                    table_mask = cv2.resize(self.table_mask, (width, height))
                else:
                    table_mask = self.table_mask

            _, self.gmm_heatmap = self.gmm_model.infer(
                frame, heatmap=self.gmm_heatmap,
                alpha_start=0.008, alpha_end=0.0004,
                table_mask=table_mask
            )

            if self.cleaning_active and tracks:
                for track in tracks:
                    if not track.is_confirmed():
                        continue

                    track_id = track.track_id
                    if track_id not in self.track_trajectories:
                        continue

                    trajectory = list(self.track_trajectories[track_id])
                    if len(trajectory) < 2:
                        continue

                    bbox = track.to_ltrb()
                    w = bbox[2] - bbox[0]
                    h = bbox[3] - bbox[1]

                    radius = int(min(w, h) * self.erasure_radius_factor)
                    radius = max(radius, 12)

                    if radius <= 0 or w <= 0 or h <= 0:
                        continue

                    for i in range(len(trajectory) - 1):
                        self._erase_at_point(trajectory[i], radius, table_mask)

        except Exception as e:
            logger.error(f"GMM update error: {e}")

    def _erase_at_point(self, point: tuple, radius: int, table_mask: np.ndarray) -> None:
        """Fast point-based erasure: subtract a Gaussian bump from the heatmap around a cleaned point."""
        if self.gmm_heatmap is None or radius <= 0:
            return

        x, y = point
        height, width = self.gmm_heatmap.shape

        # Clip the erasure window to the heatmap bounds
        y_min = max(0, y - radius)
        y_max = min(height, y + radius)
        x_min = max(0, x - radius)
        x_max = min(width, x + radius)

        if y_min >= y_max or x_min >= x_max:
            return

        y_indices, x_indices = np.ogrid[y_min:y_max, x_min:x_max]
        distance_sq = (x_indices - x)**2 + (y_indices - y)**2

        # Gaussian weight peaks at 1.0 at the point itself (sigma = radius * gaussian_sigma_factor)
        gaussian = np.exp(-distance_sq / (2 * (radius * self.gaussian_sigma_factor)**2))

        if table_mask is not None:
            gaussian = gaussian * table_mask[y_min:y_max, x_min:x_max]

        # Peak decay is 0.025 per pass, so a fully saturated cell (value 1.0)
        # clears after roughly 40 overlapping passes of the cloth
        decay = 0.025 * gaussian
        self.gmm_heatmap[y_min:y_max, x_min:x_max] = np.maximum(
            0, self.gmm_heatmap[y_min:y_max, x_min:x_max] - decay
        )

    def _update_cleaning_status(self, cloth_mask: np.ndarray) -> str:
        """Update cleaning status with hysteresis: start after cleaning_start_threshold
        consecutive detection frames, stop after cleaning_stop_threshold empty frames."""
        has_cloth = np.sum(cloth_mask) > 100

        if has_cloth:
            self.detection_frames_count += 1
            self.no_detection_frames_count = 0
        else:
            self.no_detection_frames_count += 1
            self.detection_frames_count = 0

        if not self.cleaning_active and self.detection_frames_count >= self.cleaning_start_threshold:
            self.cleaning_active = True
            return "CLEANING STARTED"
        elif self.cleaning_active and self.no_detection_frames_count >= self.cleaning_stop_threshold:
            self.cleaning_active = False
            return "CLEANING STOPPED"

        return "CLEANING ACTIVE" if self.cleaning_active else "NO CLEANING"

    def _create_visualization(self, frame: np.ndarray, cloth_mask: np.ndarray,
                              tracks: list, cleaning_status: str) -> np.ndarray:
        """Create fast visualization."""
        result = frame.copy()

        if np.sum(cloth_mask) > 0:
            overlay = result.copy()
            cloth_pixels = cloth_mask > 0
            overlay[cloth_pixels] = [0, 255, 0]
            result[cloth_pixels] = cv2.addWeighted(
                frame[cloth_pixels], 0.7, overlay[cloth_pixels], 0.3, 0
            )

        if self.gmm_heatmap is not None and self.gmm_heatmap.max() > 0:
            height, width = result.shape[:2]
            heatmap_resized = cv2.resize(self.gmm_heatmap, (width, height))
            heatmap_colored = cv2.applyColorMap(
                (heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET
            )
            significant = heatmap_resized > 0.1
            result[significant] = cv2.addWeighted(
                frame[significant], 0.6, heatmap_colored[significant], 0.4, 0
            )

        if tracks:
            for track in tracks:
                if track.is_confirmed():
                    bbox = track.to_ltrb()
                    cx, cy = int((bbox[0] + bbox[2])/2), int((bbox[1] + bbox[3])/2)
                    cv2.circle(result, (cx, cy), 4, (0, 0, 255), -1)

        status_color = (0, 255, 0) if "ACTIVE" in cleaning_status else (150, 150, 150)
        cv2.putText(result, cleaning_status, (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)

        return result

    def get_latest_frame(self) -> Optional[np.ndarray]:
        """Get latest processed frame, or None if the buffer is empty."""
        frame, _ = self.frame_buffer.get_latest()
        return frame

    def get_stats(self) -> dict:
        """Get stats."""
        with self.frame_buffer.lock:
            avg_time = np.mean(self.frame_times) if len(self.frame_times) > 0 else 0.033
            fps = 1.0 / avg_time if avg_time > 0 else 0
            return {
                "buffered_frames": len(self.frame_buffer.frames),
                "avg_fps": fps,
                "queue_size": self.input_queue.qsize(),
                "is_running": self.is_running
            }

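The monitor can also be exercised without the API. A minimal offline smoke test, assuming a trained camera under models/cam1 and a local clip sample.mp4 (both paths illustrative; runs alongside the definitions above):

    monitor = LiveHygieneMonitor(segformer_path="models/segformer_model")
    monitor.load_gmm_model("models/cam1/gmm_model.joblib")
    monitor.load_table_mask("models/cam1/mask.png")
    monitor.start_processing()

    cap = cv2.VideoCapture("sample.mp4")
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        monitor.add_frame(frame)
        out = monitor.get_latest_frame()
        if out is not None:
            cv2.imshow("hygiene", out)
            if cv2.waitKey(1) == 27:  # Esc to quit
                break
    cap.release()
    monitor.stop_processing()
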
# ==================== FASTAPI APP ====================

app = FastAPI(title="Hygiene Monitor Live Stream", version="1.0.0")

# Active streams: {stream_id: {"monitor": LiveHygieneMonitor, "cap": VideoCapture, "thread": Thread}}
active_streams = {}
streams_lock = Lock()


def _get_model_files(camera_path: str) -> tuple:
    """Extract GMM and mask paths from camera directory."""
    if not os.path.isdir(camera_path):
        raise ValueError(f"Camera path not found: {camera_path}")

    gmm_path = os.path.join(camera_path, "gmm_model.joblib")
    mask_path = os.path.join(camera_path, "mask.png")

    if not os.path.exists(gmm_path):
        raise ValueError(f"GMM model not found: {gmm_path}")
    if not os.path.exists(mask_path):
        raise ValueError(f"Mask not found: {mask_path}")

    return gmm_path, mask_path

def _stream_worker(stream_id: str, rtmp_url: str, gmm_path: str, mask_path: str):
    """Background worker: owns the capture loop for one stream."""
    try:
        monitor = LiveHygieneMonitor(
            segformer_path="models/segformer_model",
            max_buffer_frames=30
        )

        if not monitor.load_gmm_model(gmm_path):
            logger.error(f"[{stream_id}] Failed to load GMM model")
            return

        if not monitor.load_table_mask(mask_path):
            logger.error(f"[{stream_id}] Failed to load mask")
            return

        # Initialize the Discord alert manager when a webhook is configured
        webhook_url = os.getenv("DISCORD_WEBHOOK_URL")
        if webhook_url:
            monitor.alert_manager = DiscordAlertManager(webhook_url=webhook_url)
            monitor.current_camera_name = stream_id  # Or pass from request
            logger.info(f"[{stream_id}] Alert manager initialized")

        monitor.start_processing()

        cap = cv2.VideoCapture(rtmp_url)
        if not cap.isOpened():
            logger.error(f"[{stream_id}] Failed to connect to RTMP: {rtmp_url}")
            monitor.stop_processing()
            return

        # Publish the live objects to the shared registry
        with streams_lock:
            if stream_id in active_streams:
                active_streams[stream_id]["monitor"] = monitor
                active_streams[stream_id]["cap"] = cap
                active_streams[stream_id]["connected"] = True

        frame_count = 0
        logger.info(f"[{stream_id}] Connected to {rtmp_url}")

        while True:
            with streams_lock:
                if stream_id not in active_streams or not active_streams[stream_id]["running"]:
                    break

            ret, frame = cap.read()
            if not ret:
                logger.warning(f"[{stream_id}] RTMP connection lost, reconnecting...")
                cap.release()
                time.sleep(2)
                cap = cv2.VideoCapture(rtmp_url)
                continue

            monitor.add_frame(frame)
            frame_count += 1

            if frame_count % 100 == 0:
                stats = monitor.get_stats()
                logger.info(f"[{stream_id}] Frames: {frame_count}, FPS: {stats['avg_fps']:.1f}")

    except Exception as e:
        logger.error(f"[{stream_id}] Stream error: {e}")

    finally:
        with streams_lock:
            if stream_id in active_streams:
                if active_streams[stream_id]["cap"]:
                    active_streams[stream_id]["cap"].release()
                if active_streams[stream_id]["monitor"]:
                    active_streams[stream_id]["monitor"].stop_processing()
                active_streams[stream_id]["connected"] = False

        logger.info(f"[{stream_id}] Stream closed")

# ==================== ENDPOINTS ====================

@app.post("/stream/start")
async def start_stream(request: StreamStartRequest):
    """Start a new live stream."""
    stream_id = f"stream_{int(time.time() * 1000)}"

    try:
        # Extract model files from camera path
        gmm_path, mask_path = _get_model_files(request.camera_path)

        # Create stream entry
        with streams_lock:
            active_streams[stream_id] = {
                "running": True,
                "connected": False,
                "monitor": None,
                "cap": None,
                "thread": None,
                "camera_path": request.camera_path,
                "rtmp_input_url": request.rtmp_input_url  # kept so the stream can be restarted
            }

        # Start background worker thread
        thread = Thread(
            target=_stream_worker,
            args=(stream_id, request.rtmp_input_url, gmm_path, mask_path),
            daemon=True
        )
        thread.start()

        with streams_lock:
            active_streams[stream_id]["thread"] = thread

        logger.info(f"Stream {stream_id} started")
        return {
            "stream_id": stream_id,
            "status": "starting",
            "message": f"Stream {stream_id} is starting, will connect to {request.rtmp_input_url}"
        }

    except Exception as e:
        logger.error(f"Failed to start stream: {e}")
        raise HTTPException(status_code=400, detail=str(e))

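From a client, starting a stream is a single POST. A sketch assuming the server runs locally on port 8000 and a camera was trained under models/cam1 (both illustrative):

    import requests

    resp = requests.post(
        "http://localhost:8000/stream/start",
        json={
            "rtmp_input_url": "rtmp://192.168.1.100:1935/live/kitchen",  # illustrative source
            "camera_path": "models/cam1",
        },
    )
    stream_id = resp.json()["stream_id"]
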
@app.post("/stream/stop")
async def stop_stream(request: StreamStopRequest):
    """Stop a live stream."""
    stream_id = request.stream_id

    with streams_lock:
        if stream_id not in active_streams:
            raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")

        active_streams[stream_id]["running"] = False

    logger.info(f"Stream {stream_id} stop requested")
    return {"stream_id": stream_id, "status": "stopping"}


@app.get("/stream/status/{stream_id}")
async def get_stream_status(stream_id: str):
    """Get stream status."""
    with streams_lock:
        if stream_id not in active_streams:
            raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")

        stream_data = active_streams[stream_id]
        monitor = stream_data["monitor"]

    stats = monitor.get_stats() if monitor else {}

    return {
        "stream_id": stream_id,
        "connected": stream_data["connected"],
        "running": stream_data["running"],
        "camera_path": stream_data["camera_path"],
        "fps": stats.get("avg_fps", 0),
        "buffered_frames": stats.get("buffered_frames", 0),
        "queue_size": stats.get("queue_size", 0)
    }

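A client typically polls the status endpoint until "connected" flips to true; same local-server assumption, with an illustrative stream id:

    import time
    import requests

    stream_id = "stream_1700000000000"  # illustrative id returned by /stream/start
    for _ in range(30):
        status = requests.get(f"http://localhost:8000/stream/status/{stream_id}").json()
        if status["connected"]:
            print(f"live at {status['fps']:.1f} FPS")
            break
        time.sleep(1)
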
@app.get("/stream/video/{stream_id}")
async def stream_video(stream_id: str):
    """Stream video frames via MJPEG."""
    with streams_lock:
        if stream_id not in active_streams:
            raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")

        monitor = active_streams[stream_id]["monitor"]

    if not monitor:
        raise HTTPException(status_code=503, detail="Monitor not ready")

    async def frame_generator():
        while True:
            with streams_lock:
                if stream_id not in active_streams or not active_streams[stream_id]["running"]:
                    break

            frame = monitor.get_latest_frame()
            if frame is not None:
                _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
                yield (b'--frame\r\n'
                       b'Content-Type: image/jpeg\r\n'
                       b'Content-Length: ' + str(len(buffer)).encode() + b'\r\n\r\n'
                       + buffer.tobytes() + b'\r\n')
                # Cap output at roughly 30 fps instead of busy-looping on the same frame
                await asyncio.sleep(0.033)
            else:
                await asyncio.sleep(0.01)

    return StreamingResponse(
        frame_generator(),
        media_type="multipart/x-mixed-replace; boundary=frame"
    )

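The MJPEG stream renders directly in a browser img tag; OpenCV's FFmpeg backend can usually open it as well, though that depends on the build. An illustrative sketch:

    import cv2

    stream_id = "stream_1700000000000"  # illustrative
    view = cv2.VideoCapture(f"http://localhost:8000/stream/video/{stream_id}")
    ret, frame = view.read()  # one processed frame with overlays
    if ret:
        cv2.imwrite("snapshot.jpg", frame)
    view.release()
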
@app.get("/streams")
async def list_streams():
    """List all active streams."""
    with streams_lock:
        streams_list = []
        for stream_id, data in active_streams.items():
            monitor = data["monitor"]
            stats = monitor.get_stats() if monitor else {}

            streams_list.append({
                "stream_id": stream_id,
                "connected": data["connected"],
                "running": data["running"],
                "camera_path": data["camera_path"],
                "fps": stats.get("avg_fps", 0),
                "buffered_frames": stats.get("buffered_frames", 0)
            })

    return {"total_streams": len(streams_list), "streams": streams_list}


@app.post("/stream/restart/{stream_id}")
async def restart_stream(stream_id: str):
    """Restart a stream: signal the old worker to exit, then spawn a new one."""
    with streams_lock:
        if stream_id not in active_streams:
            raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")

        active_streams[stream_id]["running"] = False

    # Give the old worker time to notice the flag and release the capture
    await asyncio.sleep(2)

    with streams_lock:
        data = active_streams[stream_id]
        gmm_path, mask_path = _get_model_files(data["camera_path"])
        data["running"] = True
        thread = Thread(
            target=_stream_worker,
            args=(stream_id, data["rtmp_input_url"], gmm_path, mask_path),
            daemon=True
        )
        data["thread"] = thread
        thread.start()

    return {"stream_id": stream_id, "status": "restarting"}

@app.post("/camera/extract_frame")
async def extract_frame_from_rtmp(request: dict):
    """
    Extract a single frame from an RTMP stream for corner selection.

    Request body:
    {
        "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
        "camera_name": "kitchen"
    }

    Returns:
    {
        "success": true,
        "frame_base64": "base64_encoded_image",
        "frame_dimensions": {"width": 1920, "height": 1080}
    }
    """
    try:
        rtmp_url = request.get("rtmp_url")
        camera_name = request.get("camera_name")

        if not rtmp_url or not camera_name:
            raise HTTPException(status_code=400, detail="Missing rtmp_url or camera_name")

        # Connect to RTMP stream
        cap = cv2.VideoCapture(rtmp_url)
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail=f"Failed to connect to RTMP: {rtmp_url}")

        # Read first frame
        ret, frame = cap.read()
        cap.release()

        if not ret:
            raise HTTPException(status_code=400, detail="Failed to read frame from RTMP stream")

        # Convert frame to base64 for frontend display
        import base64
        _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
        frame_base64 = base64.b64encode(buffer).decode('utf-8')

        # Store frame temporarily for training (optional - could store in memory cache)
        temp_dir = "temp_frames"
        os.makedirs(temp_dir, exist_ok=True)
        temp_frame_path = os.path.join(temp_dir, f"{camera_name}_reference.jpg")
        cv2.imwrite(temp_frame_path, frame)

        return {
            "success": True,
            "frame_base64": frame_base64,
            "frame_dimensions": {
                "width": frame.shape[1],
                "height": frame.shape[0]
            },
            "temp_frame_path": temp_frame_path
        }

    except HTTPException:
        # Preserve intended status codes instead of collapsing them into 500s
        raise
    except Exception as e:
        logger.error(f"Extract frame error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

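On the client side the returned base64 payload decodes back into an image for corner picking; a sketch assuming a local server:

    import base64
    import cv2
    import numpy as np
    import requests

    resp = requests.post(
        "http://localhost:8000/camera/extract_frame",  # assumes a local server
        json={"rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",  # illustrative
              "camera_name": "kitchen"},
    ).json()
    jpg = base64.b64decode(resp["frame_base64"])
    ref = cv2.imdecode(np.frombuffer(jpg, dtype=np.uint8), cv2.IMREAD_COLOR)
    cv2.imwrite("kitchen_reference.jpg", ref)  # pick corner points on this image
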
@app.post("/camera/train_gmm")
async def train_gmm_from_rtmp(request: dict):
    """
    Train a GMM model from an RTMP stream using N corner points (minimum 4).

    Request body:
    {
        "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
        "camera_name": "kitchen",
        "corner_points": [
            {"x": 100, "y": 50},
            {"x": 400, "y": 45},
            {"x": 700, "y": 55},
            {"x": 800, "y": 60},
            {"x": 850, "y": 300},
            {"x": 850, "y": 600},
            {"x": 400, "y": 620},
            {"x": 50, "y": 580},
            {"x": 45, "y": 300}
        ],                              // 4+ points; more points allow curved tables
        "max_frames": 250,
        "use_perspective_warp": false   // set false for non-rectangular tables
    }
    """
    try:
        rtmp_url = request.get("rtmp_url")
        camera_name = request.get("camera_name")
        corner_points = request.get("corner_points")
        max_frames = request.get("max_frames", 250)
        use_perspective_warp = request.get("use_perspective_warp", False)

        # Validation
        if not rtmp_url or not camera_name or not corner_points:
            raise HTTPException(status_code=400, detail="Missing required parameters")

        if len(corner_points) < 4:
            raise HTTPException(status_code=400, detail="Minimum 4 corner points required")

        logger.info(f"Starting GMM training for camera: {camera_name} with {len(corner_points)} points")

        # ===== STEP 1: Connect to RTMP and capture frames =====
        cap = cv2.VideoCapture(rtmp_url)
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail=f"Failed to connect to RTMP: {rtmp_url}")

        ret, first_frame = cap.read()
        if not ret:
            cap.release()
            raise HTTPException(status_code=400, detail="Failed to read from RTMP stream")

        h, w = first_frame.shape[:2]

        # ===== STEP 2: Create polygon mask from N points =====
        pts_polygon = np.array([
            [point['x'], point['y']] for point in corner_points
        ], dtype=np.int32)

        # Create binary mask for the table area
        table_mask = np.zeros((h, w), dtype=np.uint8)
        cv2.fillPoly(table_mask, [pts_polygon], 255)

        # ===== STEP 3: Decide transformation strategy =====
        import tempfile
        temp_dir = tempfile.mkdtemp()
        frame_count = 0

        if use_perspective_warp and len(corner_points) == 4:
            # ===== STRATEGY A: Perspective warp (rectangular tables only) =====
            logger.info("Using perspective warp for rectangular table")

            pts_src = np.array([
                [corner_points[0]['x'], corner_points[0]['y']],
                [corner_points[1]['x'], corner_points[1]['y']],
                [corner_points[2]['x'], corner_points[2]['y']],
                [corner_points[3]['x'], corner_points[3]['y']]
            ], dtype=np.float32)

            pts_dst = np.array([
                [0, 0], [w, 0], [w, h], [0, h]
            ], dtype=np.float32)

            matrix = cv2.getPerspectiveTransform(pts_src, pts_dst)

            # Capture and warp frames
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # no-op on live RTMP; kept for file-based inputs
            while frame_count < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break

                warped = cv2.warpPerspective(frame, matrix, (w, h))
                frame_path = os.path.join(temp_dir, f'b{frame_count:05d}.png')
                cv2.imwrite(frame_path, warped)
                frame_count += 1

                if frame_count % 50 == 0:
                    logger.info(f"Captured {frame_count}/{max_frames} frames")

            # For warped images, mask should be full frame (already aligned)
            final_mask = np.ones((h, w), dtype=np.uint8) * 255

        else:
            # ===== STRATEGY B: Direct masking (curved/complex tables) =====
            logger.info(f"Using direct masking for {len(corner_points)}-point polygon (curved table)")

            # Capture frames WITHOUT warping; the polygon mask zeroes out off-table pixels
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # no-op on live RTMP; kept for file-based inputs
            while frame_count < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break

                masked_frame = cv2.bitwise_and(frame, frame, mask=table_mask)

                frame_path = os.path.join(temp_dir, f'b{frame_count:05d}.png')
                cv2.imwrite(frame_path, masked_frame)
                frame_count += 1

                if frame_count % 50 == 0:
                    logger.info(f"Captured {frame_count}/{max_frames} frames")

            # Use original polygon mask
            final_mask = table_mask

        cap.release()

        if frame_count == 0:
            raise HTTPException(status_code=400, detail="No frames captured")

        logger.info(f"Captured {frame_count} frames, starting GMM training...")

        # ===== STEP 4: Train GMM =====
        from GMM import GMM
        gmm = GMM(temp_dir, frame_count, alpha=0.05)
        gmm.train(K=4)
        logger.info("GMM training complete")

        # ===== STEP 5: Save artifacts =====
        camera_path = os.path.join("models", camera_name)
        os.makedirs(camera_path, exist_ok=True)

        # 1. Save GMM model
        gmm_path = os.path.join(camera_path, "gmm_model.joblib")
        gmm.save_model(gmm_path)

        # 2. Save mask (polygon-based, not rectangular)
        mask_path = os.path.join(camera_path, "mask.png")
        cv2.imwrite(mask_path, final_mask)
        logger.info(f"Saved {len(corner_points)}-point polygon mask to {mask_path}")

        # 3. Create thumbnail with polygon overlay
        thumb_frame = first_frame.copy()

        # Draw filled polygon with transparency
        overlay = thumb_frame.copy()
        cv2.fillPoly(overlay, [pts_polygon], (0, 255, 0))
        cv2.addWeighted(thumb_frame, 0.7, overlay, 0.3, 0, thumb_frame)

        # Draw polygon border
        cv2.polylines(thumb_frame, [pts_polygon], True, (0, 255, 0), 3)

        # Draw corner points with numbers
        colors = [
            (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
            (255, 0, 255), (0, 255, 255), (128, 0, 128), (255, 128, 0)
        ]

        for i, point in enumerate(corner_points):
            x, y = point['x'], point['y']
            color = colors[i % len(colors)]

            cv2.circle(thumb_frame, (x, y), 8, color, -1)
            cv2.circle(thumb_frame, (x, y), 10, (255, 255, 255), 2)

            # Point number
            cv2.putText(thumb_frame, str(i + 1), (x + 15, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        # Camera name label
        cv2.putText(thumb_frame, camera_name, (30, 50),
                    cv2.FONT_HERSHEY_DUPLEX, 1.5, (255, 255, 255), 3)
        cv2.putText(thumb_frame, camera_name, (30, 50),
                    cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 255, 0), 2)

        # Point count indicator
        cv2.putText(thumb_frame, f"{len(corner_points)} points", (30, 90),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        thumb_path = os.path.join(camera_path, "thumb.png")
        cv2.imwrite(thumb_path, thumb_frame)

        # 4. Save polygon metadata so the setup can be reconstructed later
        metadata = {
            "camera_name": camera_name,
            "num_points": len(corner_points),
            "corner_points": corner_points,
            "frame_dimensions": {"width": w, "height": h},
            "use_perspective_warp": use_perspective_warp,
            "training_date": datetime.now().isoformat()
        }

        import json
        metadata_path = os.path.join(camera_path, "metadata.json")
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"Saved metadata to {metadata_path}")

        # Cleanup
        import shutil
        shutil.rmtree(temp_dir)

        logger.info(f"✅ Camera '{camera_name}' training complete with {len(corner_points)}-point polygon!")

        return {
            "success": True,
            "camera_name": camera_name,
            "camera_path": camera_path,
            "frames_captured": frame_count,
            "polygon_points": len(corner_points),
            "use_perspective_warp": use_perspective_warp,
            "model_files": {
                "gmm_model": gmm_path,
                "mask": mask_path,
                "thumbnail": thumb_path,
                "metadata": metadata_path
            }
        }

    except HTTPException:
        # Preserve intended status codes instead of collapsing them into 500s
        raise
    except Exception as e:
        logger.error(f"GMM training error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))

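An end-to-end training call; the polygon below is illustrative, and the endpoint blocks until max_frames frames are captured, so allow a generous timeout:

    import requests

    payload = {
        "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",  # illustrative
        "camera_name": "kitchen",
        "corner_points": [
            {"x": 100, "y": 50}, {"x": 800, "y": 60},
            {"x": 850, "y": 600}, {"x": 50, "y": 580},
        ],
        "max_frames": 250,
        "use_perspective_warp": False,
    }
    resp = requests.post("http://localhost:8000/camera/train_gmm",
                         json=payload, timeout=600)
    print(resp.json()["model_files"])  # artifacts written under models/kitchen/
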
@app.get("/cameras")
async def list_cameras():
    """
    List all trained cameras with their metadata.

    Returns:
    {
        "cameras": [
            {
                "name": "kitchen",
                "path": "models/kitchen",
                "thumbnail": "models/kitchen/thumb.png",
                "has_gmm_model": true,
                "has_mask": true
            }
        ]
    }
    """
    try:
        cameras = []
        models_dir = "models"

        if not os.path.exists(models_dir):
            return {"cameras": []}

        for camera_name in os.listdir(models_dir):
            camera_path = os.path.join(models_dir, camera_name)

            if not os.path.isdir(camera_path):
                continue

            gmm_path = os.path.join(camera_path, "gmm_model.joblib")
            mask_path = os.path.join(camera_path, "mask.png")
            thumb_path = os.path.join(camera_path, "thumb.png")

            cameras.append({
                "name": camera_name,
                "path": camera_path,
                "thumbnail": thumb_path if os.path.exists(thumb_path) else None,
                "has_gmm_model": os.path.exists(gmm_path),
                "has_mask": os.path.exists(mask_path)
            })

        return {"cameras": cameras}

    except Exception as e:
        logger.error(f"List cameras error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.delete("/camera/{camera_name}")
async def delete_camera(camera_name: str):
    """Delete a trained camera and all its files."""
    try:
        camera_path = os.path.join("models", camera_name)

        if not os.path.exists(camera_path):
            raise HTTPException(status_code=404, detail=f"Camera '{camera_name}' not found")

        import shutil
        shutil.rmtree(camera_path)

        logger.info(f"Deleted camera: {camera_name}")

        return {
            "success": True,
            "message": f"Camera '{camera_name}' deleted successfully"
        }

    except HTTPException:
        # Let the 404 through instead of turning it into a 500
        raise
    except Exception as e:
        logger.error(f"Delete camera error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    with streams_lock:
        stream_count = len(active_streams)

    return {
        "status": "healthy",
        "active_streams": stream_count,
        "timestamp": datetime.now().isoformat()
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
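
A quick liveness check once the service is up (python app.py binds 0.0.0.0:8000 as above; the localhost URL is an assumption about where you run it):

    import requests

    print(requests.get("http://localhost:8000/health").json())
    # expected: {"status": "healthy", "active_streams": 0, "timestamp": "..."}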
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from collections import deque
5
+ from threading import Thread, Lock
6
+ from queue import Queue
7
+ import time
8
+ import logging
9
+ import os
10
+ from datetime import datetime
11
+ from PIL import Image
12
+ from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
13
+ from fastapi import FastAPI, HTTPException
14
+ from fastapi.responses import FileResponse, StreamingResponse
15
+ import asyncio
16
+ import uvicorn
17
+ from pydantic import BaseModel
18
+ from typing import Optional
19
+ import requests
20
+ from datetime import datetime, timedelta
21
+
22
+ # ===== IMPORT THE DISCORD ALERT MANAGER =====
23
+ from send_discord import DiscordAlertManager
24
+
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger(__name__)
27
+
28
+ # ==================== DATA MODELS ====================
29
+
30
+ class StreamStartRequest(BaseModel):
31
+ """Start streaming request."""
32
+ rtmp_input_url: str
33
+ camera_path: str # e.g., "models/cam1" - will auto-pick gmm_model.joblib and mask.png
34
+
35
+
36
+ class StreamStopRequest(BaseModel):
37
+ """Stop streaming request."""
38
+ stream_id: str
39
+
40
+
41
+ class StreamStatusResponse(BaseModel):
42
+ """Stream status response."""
43
+ stream_id: str
44
+ status: str
45
+ fps: float
46
+ buffered_frames: int
47
+ queue_size: int
48
+
49
+
50
+ # ==================== CIRCULAR BUFFER ====================
51
+
52
+ class CircularFrameBuffer:
53
+ """Fixed-size buffer for storing processed frames."""
54
+
55
+ def __init__(self, max_frames: int = 30):
56
+ self.max_frames = max_frames
57
+ self.frames = deque(maxlen=max_frames)
58
+ self.lock = Lock()
59
+ self.sequence_ids = deque(maxlen=max_frames)
60
+
61
+ def add_frame(self, frame: np.ndarray, seq_id: int) -> None:
62
+ """Add processed frame to buffer."""
63
+ with self.lock:
64
+ self.frames.append(frame.copy())
65
+ self.sequence_ids.append(seq_id)
66
+
67
+ def get_latest(self) -> tuple:
68
+ """Get most recent frame."""
69
+ with self.lock:
70
+ if len(self.frames) > 0:
71
+ return self.frames[-1].copy(), self.sequence_ids[-1]
72
+ return None, None
73
+
74
+ def clear(self) -> None:
75
+ """Clear buffer."""
76
+ with self.lock:
77
+ self.frames.clear()
78
+ self.sequence_ids.clear()
79
+
80
+
81
+ # ==================== LIVE MONITOR ====================
82
+
83
+ class LiveHygieneMonitor:
84
+ """Production-ready hygiene monitor for live streams."""
85
+
86
+ def __init__(self, segformer_path: str, max_buffer_frames: int = 30):
87
+ self.segformer_path = segformer_path
88
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
+
90
+ # Model loading
91
+ self.model = None
92
+ self.processor = None
93
+ self._load_segformer()
94
+
95
+ # GMM components
96
+ self.gmm_model = None
97
+ self.gmm_heatmap = None
98
+ self.table_mask = None
99
+
100
+ # Live streaming state
101
+ self.frame_buffer = CircularFrameBuffer(max_frames=max_buffer_frames)
102
+ self.input_queue = Queue(maxsize=5)
103
+ self.processing_thread = None
104
+ self.is_running = False
105
+
106
+ # Frame sequence tracking
107
+ self.frame_sequence = 0
108
+ self.frame_lock = Lock()
109
+
110
+ # State management
111
+ self.detection_frames_count = 0
112
+ self.no_detection_frames_count = 0
113
+ self.cleaning_active = False
114
+ self.cleaning_start_threshold = 4
115
+ self.cleaning_stop_threshold = 12
116
+
117
+ # Performance tracking
118
+ self.frame_times = deque(maxlen=30)
119
+ self.last_frame_time = time.time()
120
+
121
+ # Optimization flags
122
+ self.skip_segformer_every_n_frames = 2
123
+ self.segformer_skip_counter = 0
124
+ self.last_cloth_mask = None
125
+
126
+ # Visualization settings
127
+ self.show_cloth_detection = True
128
+ self.erasure_radius_factor = 0.2
129
+ self.gaussian_sigma_factor = 0.8
130
+
131
+ self.tracker = None
132
+ self.track_trajectories = {}
133
+ self.max_trajectory_length = 40
134
+ self.track_colors = {}
135
+
136
+ # Alert manager - ADD THIS
137
+ self.alert_manager = None
138
+ self.current_camera_name = "Default Camera"
139
+
140
+ logger.info(f"Live Monitor initialized on {self.device}")
141
+
142
+ def _load_segformer(self):
143
+ """Load SegFormer model."""
144
+ try:
145
+ self.model = SegformerForSemanticSegmentation.from_pretrained(self.segformer_path)
146
+ self.processor = SegformerImageProcessor(do_reduce_labels=False)
147
+ self.model.to(self.device)
148
+ self.model.eval()
149
+ logger.info(f"SegFormer loaded on {self.device}")
150
+ except Exception as e:
151
+ logger.error(f"Failed to load SegFormer: {e}")
152
+
153
+ def _init_tracker(self):
154
+ """Lazy-init tracker."""
155
+ if self.tracker is None:
156
+ from deep_sort_realtime.deepsort_tracker import DeepSort
157
+ self.tracker = DeepSort(
158
+ max_age=15,
159
+ n_init=2,
160
+ nms_max_overlap=0.7,
161
+ max_cosine_distance=0.4,
162
+ nn_budget=50,
163
+ embedder="mobilenet",
164
+ half=True,
165
+ embedder_gpu=torch.cuda.is_available()
166
+ )
167
+
168
+ def load_gmm_model(self, gmm_path: str) -> bool:
169
+ """Load GMM model."""
170
+ try:
171
+ from GMM import GMM
172
+ self.gmm_model = GMM.load_model(gmm_path)
173
+ if self.gmm_model.img_shape:
174
+ h, w = self.gmm_model.img_shape[:2]
175
+ self.gmm_heatmap = np.zeros((h, w), dtype=np.float32)
176
+ logger.info("GMM model loaded")
177
+ return True
178
+ except Exception as e:
179
+ logger.error(f"Failed to load GMM: {e}")
180
+ return False
181
+
182
+ def load_table_mask(self, mask_path: str) -> bool:
183
+ """Load table mask."""
184
+ try:
185
+ mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
186
+ self.table_mask = (mask > 128).astype(np.uint8)
187
+ logger.info(f"Table mask loaded: {mask.shape}")
188
+ return True
189
+ except Exception as e:
190
+ logger.error(f"Failed to load mask: {e}")
191
+ return False
192
+
193
+ def add_frame(self, frame: np.ndarray) -> None:
194
+ """Add incoming frame (non-blocking)."""
195
+ try:
196
+ self.input_queue.put_nowait(frame)
197
+ except:
198
+ pass
199
+
200
+ def start_processing(self) -> None:
201
+ """Start background processing."""
202
+ if self.is_running:
203
+ return
204
+ self.is_running = True
205
+ self.processing_thread = Thread(target=self._process_loop, daemon=True)
206
+ self.processing_thread.start()
207
+ logger.info("Processing thread started")
208
+
209
+ def stop_processing(self) -> None:
210
+ """Stop processing."""
211
+ self.is_running = False
212
+ if self.processing_thread:
213
+ self.processing_thread.join(timeout=5)
214
+ self.frame_buffer.clear()
215
+ logger.info("Processing stopped")
216
+
217
+ def _get_next_sequence_id(self) -> int:
218
+ """Thread-safe sequence ID."""
219
+ with self.frame_lock:
220
+ self.frame_sequence += 1
221
+ return self.frame_sequence
222
+
223
+ def _process_loop(self) -> None:
224
+ """Main processing loop."""
225
+ while self.is_running:
226
+ try:
227
+ frame = self.input_queue.get(timeout=1)
228
+ seq_id = self._get_next_sequence_id()
229
+
230
+ frame = self._resize_frame(frame, target_width=1024)
231
+ cloth_mask = self._detect_cloth_fast(frame)
232
+ cleaning_status = self._update_cleaning_status(cloth_mask)
233
+
234
+ tracks = None
235
+ if self.cleaning_active:
236
+ self._init_tracker()
237
+ tracks = self._track_cloth(frame, cloth_mask)
238
+
239
+ self._update_gmm_fast(frame, cloth_mask, tracks)
240
+ viz_frame = self._create_visualization(frame, cloth_mask, tracks, cleaning_status)
241
+ self.frame_buffer.add_frame(viz_frame, seq_id)
242
+
243
+ elapsed = time.time() - self.last_frame_time
244
+ self.frame_times.append(elapsed)
245
+ self.last_frame_time = time.time()
246
+
247
+ if seq_id % 30 == 0:
248
+ avg_time = np.mean(self.frame_times)
249
+ fps = 1.0 / avg_time if avg_time > 0 else 0
250
+ logger.info(f"Seq {seq_id} | {fps:.1f} FPS | {cleaning_status}")
251
+
252
+ except Exception as e:
253
+ logger.error(f"Processing error: {e}")
254
+ continue
255
+
256
+ def _resize_frame(self, frame: np.ndarray, target_width: int = 1024) -> np.ndarray:
257
+ """Resize frame."""
258
+ h, w = frame.shape[:2]
259
+ if w > target_width:
260
+ scale = target_width / w
261
+ new_h = int(h * scale)
262
+ return cv2.resize(frame, (target_width, new_h))
263
+ return frame
264
+
265
+ def _detect_cloth_fast(self, frame: np.ndarray) -> np.ndarray:
266
+ """Fast cloth detection with skipping."""
267
+ if self.model is None:
268
+ return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
269
+
270
+ self.segformer_skip_counter += 1
271
+ if self.segformer_skip_counter < self.skip_segformer_every_n_frames:
272
+ if self.last_cloth_mask is not None:
273
+ return self.last_cloth_mask
274
+ return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
275
+
276
+ self.segformer_skip_counter = 0
277
+
278
+ try:
279
+ height, width = frame.shape[:2]
280
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
281
+ pil_image = Image.fromarray(frame_rgb)
282
+
283
+ with torch.no_grad():
284
+ inputs = self.processor(images=pil_image, return_tensors="pt")
285
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
286
+ outputs = self.model(**inputs)
287
+ logits = outputs.logits
288
+
289
+ upsampled = torch.nn.functional.interpolate(
290
+ logits, size=(height, width), mode="bilinear", align_corners=False
291
+ )
292
+
293
+ cloth_mask = (upsampled.argmax(dim=1)[0].cpu().numpy() == 1).astype(np.uint8)
294
+
295
+ if self.table_mask is not None:
296
+ if self.table_mask.shape != cloth_mask.shape:
297
+ table_resized = cv2.resize(self.table_mask, (width, height))
298
+ else:
299
+ table_resized = self.table_mask
300
+ cloth_mask = cloth_mask * table_resized
301
+
302
+ self.last_cloth_mask = cloth_mask
303
+ return cloth_mask
304
+
305
+ except Exception as e:
306
+ logger.error(f"Cloth detection error: {e}")
307
+ return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
308
+
309
+ def _track_cloth(self, frame: np.ndarray, cloth_mask: np.ndarray) -> list:
310
+ """Fast tracking."""
311
+ if self.tracker is None:
312
+ return []
313
+
314
+ try:
315
+ contours, _ = cv2.findContours(cloth_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
316
+ detections = []
317
+
318
+ for contour in contours:
319
+ area = cv2.contourArea(contour)
320
+ if area < 150:
321
+ continue
322
+ x, y, w, h = cv2.boundingRect(contour)
323
+ if w > 0 and h > 0:
324
+ detections.append(([x, y, w, h], 0.95, 'cloth'))
325
+
326
+ if not detections:
327
+ return []
328
+
329
+ tracks = self.tracker.update_tracks(detections, frame=frame)
330
+
331
+ height, width = frame.shape[:2]
332
+ for track in tracks:
333
+ if not track.is_confirmed():
334
+ continue
335
+
336
+ track_id = track.track_id
337
+ bbox = track.to_ltrb()
338
+ cx = int((bbox[0] + bbox[2]) / 2)
339
+ cy = int((bbox[1] + bbox[3]) / 2)
340
+
341
+ if 0 <= cx < width and 0 <= cy < height:
342
+ if track_id not in self.track_trajectories:
343
+ self.track_trajectories[track_id] = deque(maxlen=self.max_trajectory_length)
344
+ self.track_colors[track_id] = (255, 255, 0)
345
+ self.track_trajectories[track_id].append((cx, cy))
346
+
347
+ active_ids = {track.track_id for track in tracks if track.is_confirmed()}
348
+ dead_ids = set(self.track_trajectories.keys()) - active_ids
349
+ for dead_id in dead_ids:
350
+ self.track_trajectories.pop(dead_id, None)
351
+ self.track_colors.pop(dead_id, None)
352
+
353
+ return tracks
354
+
355
+ except Exception as e:
356
+ logger.error(f"Tracking error: {e}")
357
+ return []
358
+
359
+ def _update_gmm_fast(self, frame: np.ndarray, cloth_mask: np.ndarray, tracks: list) -> None:
360
+ """Lightweight GMM update."""
361
+ if self.gmm_model is None:
362
+ return
363
+
364
+ try:
365
+ height, width = frame.shape[:2]
366
+ table_mask = None
367
+ if self.table_mask is not None:
368
+ if self.table_mask.shape != (height, width):
369
+ table_mask = cv2.resize(self.table_mask, (width, height))
370
+ else:
371
+ table_mask = self.table_mask
372
+
373
+ _, self.gmm_heatmap = self.gmm_model.infer(
374
+ frame, heatmap=self.gmm_heatmap,
375
+ alpha_start=0.008, alpha_end=0.0004,
376
+ table_mask=table_mask
377
+ )
378
+
379
+ if self.cleaning_active and tracks:
380
+ for track in tracks:
381
+ if not track.is_confirmed():
382
+ continue
383
+
384
+ track_id = track.track_id
385
+ if track_id not in self.track_trajectories:
386
+ continue
387
+
388
+ trajectory = list(self.track_trajectories[track_id])
389
+ if len(trajectory) < 2:
390
+ continue
391
+
392
+ bbox = track.to_ltrb()
393
+ w = bbox[2] - bbox[0]
394
+ h = bbox[3] - bbox[1]
395
+
396
+ radius = int(min(w, h) * self.erasure_radius_factor)
397
+ radius = max(radius, 12)
398
+
399
+ if radius <= 0 or w <= 0 or h <= 0:
400
+ continue
401
+
402
+ for i in range(len(trajectory) - 1):
403
+ self._erase_at_point(trajectory[i], radius, table_mask)
404
+
405
+ except Exception as e:
406
+ logger.error(f"GMM update error: {e}")
407
+
408
+ def _erase_at_point(self, point: tuple, radius: int, table_mask: np.ndarray) -> None:
409
+ """Fast point-based erasure."""
410
+ if self.gmm_heatmap is None or radius <= 0:
411
+ return
412
+
413
+ x, y = point
414
+ height, width = self.gmm_heatmap.shape
415
+
416
+ y_min = max(0, y - radius)
417
+ y_max = min(height, y + radius)
418
+ x_min = max(0, x - radius)
419
+ x_max = min(width, x + radius)
420
+
421
+ if y_min >= y_max or x_min >= x_max:
422
+ return
423
+
424
+ y_indices, x_indices = np.ogrid[y_min:y_max, x_min:x_max]
425
+ distance_sq = (x_indices - x)**2 + (y_indices - y)**2
426
+
427
+ gaussian = np.exp(-distance_sq / (2 * (radius * self.gaussian_sigma_factor)**2))
428
+
429
+ if table_mask is not None:
430
+ gaussian = gaussian * table_mask[y_min:y_max, x_min:x_max]
431
+
432
+ decay = 0.025 * gaussian
433
+ self.gmm_heatmap[y_min:y_max, x_min:x_max] = np.maximum(
434
+ 0, self.gmm_heatmap[y_min:y_max, x_min:x_max] - decay
435
+ )
436
+
437
+ def _update_cleaning_status(self, cloth_mask: np.ndarray) -> str:
438
+ """Update cleaning status."""
439
+ has_cloth = np.sum(cloth_mask) > 100
440
+
441
+ if has_cloth:
442
+ self.detection_frames_count += 1
443
+ self.no_detection_frames_count = 0
444
+ else:
445
+ self.no_detection_frames_count += 1
446
+ self.detection_frames_count = 0
447
+
448
+ if not self.cleaning_active and self.detection_frames_count >= self.cleaning_start_threshold:
449
+ self.cleaning_active = True
450
+ return "CLEANING STARTED"
451
+ elif self.cleaning_active and self.no_detection_frames_count >= self.cleaning_stop_threshold:
452
+ self.cleaning_active = False
453
+ return "CLEANING STOPPED"
454
+
455
+ return "CLEANING ACTIVE" if self.cleaning_active else "NO CLEANING"
456
+
457
+ def _create_visualization(self, frame: np.ndarray, cloth_mask: np.ndarray,
458
+ tracks: list, cleaning_status: str) -> np.ndarray:
459
+ """Create fast visualization."""
460
+ result = frame.copy()
461
+
462
+ if np.sum(cloth_mask) > 0:
463
+ overlay = result.copy()
464
+ cloth_pixels = cloth_mask > 0
465
+ overlay[cloth_pixels] = [0, 255, 0]
466
+ result[cloth_pixels] = cv2.addWeighted(
467
+ frame[cloth_pixels], 0.7, overlay[cloth_pixels], 0.3, 0
468
+ )
469
+
470
+ if self.gmm_heatmap is not None and self.gmm_heatmap.max() > 0:
471
+ height, width = result.shape[:2]
472
+ heatmap_resized = cv2.resize(self.gmm_heatmap, (width, height))
473
+ heatmap_colored = cv2.applyColorMap(
474
+ (heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET
475
+ )
476
+ significant = heatmap_resized > 0.1
477
+ result[significant] = cv2.addWeighted(
478
+ frame[significant], 0.6, heatmap_colored[significant], 0.4, 0
479
+ )
480
+
481
+ if tracks:
482
+ for track in tracks:
483
+ if track.is_confirmed():
484
+ bbox = track.to_ltrb()
485
+ cx, cy = int((bbox[0] + bbox[2])/2), int((bbox[1] + bbox[3])/2)
486
+ cv2.circle(result, (cx, cy), 4, (0, 0, 255), -1)
487
+
488
+ status_color = (0, 255, 0) if "ACTIVE" in cleaning_status else (150, 150, 150)
489
+ cv2.putText(result, cleaning_status, (20, 40),
490
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)
491
+
492
+ return result
493
+
494
+ def get_latest_frame(self) -> np.ndarray:
495
+ """Get latest processed frame."""
496
+ frame, _ = self.frame_buffer.get_latest()
497
+ return frame
498
+
499
+ def get_stats(self) -> dict:
500
+ """Get stats."""
501
+ with self.frame_buffer.lock:
502
+ avg_time = np.mean(self.frame_times) if len(self.frame_times) > 0 else 0.033
503
+ fps = 1.0 / avg_time if avg_time > 0 else 0
504
+ return {
505
+ "buffered_frames": len(self.frame_buffer.frames),
506
+ "avg_fps": fps,
507
+ "queue_size": self.input_queue.qsize(),
508
+ "is_running": self.is_running
509
+ }
510
+
511
+
512
+ # ==================== FASTAPI APP ====================
513
+
514
+ app = FastAPI(title="Hygiene Monitor Live Stream", version="1.0.0")
515
+
516
+ # Active streams: {stream_id: {"monitor": LiveHygieneMonitor, "cap": VideoCapture, "thread": Thread}}
517
+ active_streams = {}
518
+ streams_lock = Lock()
519
+
520
+
521
+ def _get_model_files(camera_path: str) -> tuple:
522
+ """Extract GMM and mask paths from camera directory."""
523
+ if not os.path.isdir(camera_path):
524
+ raise ValueError(f"Camera path not found: {camera_path}")
525
+
526
+ gmm_path = os.path.join(camera_path, "gmm_model.joblib")
527
+ mask_path = os.path.join(camera_path, "mask.png")
528
+
529
+ if not os.path.exists(gmm_path):
530
+ raise ValueError(f"GMM model not found: {gmm_path}")
531
+ if not os.path.exists(mask_path):
532
+ raise ValueError(f"Mask not found: {mask_path}")
533
+
534
+ return gmm_path, mask_path
535
+
536
+
537
+ def _stream_worker(stream_id: str, rtmp_url: str, gmm_path: str, mask_path: str):
+     """Background worker: owns the RTMP capture and feeds frames to the monitor."""
+     try:
+         monitor = LiveHygieneMonitor(
+             segformer_path="models/segformer_model",
+             max_buffer_frames=30
+         )
+ 
+         if not monitor.load_gmm_model(gmm_path):
+             logger.error(f"[{stream_id}] Failed to load GMM model")
+             return
+ 
+         if not monitor.load_table_mask(mask_path):
+             logger.error(f"[{stream_id}] Failed to load mask")
+             return
+ 
+         # Initialize the Discord alert manager if a webhook is configured
+         webhook_url = os.getenv("DISCORD_WEBHOOK_URL")
+         if webhook_url:
+             monitor.alert_manager = DiscordAlertManager(webhook_url=webhook_url)
+             monitor.current_camera_name = stream_id  # could also be passed in the request
+             logger.info(f"[{stream_id}] Alert manager initialized")
+ 
+         monitor.start_processing()
+ 
+         cap = cv2.VideoCapture(rtmp_url)
+         if not cap.isOpened():
+             logger.error(f"[{stream_id}] Failed to connect to RTMP: {rtmp_url}")
+             monitor.stop_processing()
+             return
+ 
+         # Register the live handles so status endpoints and cleanup can see them
+         with streams_lock:
+             if stream_id in active_streams:
+                 active_streams[stream_id]["monitor"] = monitor
+                 active_streams[stream_id]["cap"] = cap
+                 active_streams[stream_id]["connected"] = True
+ 
+         frame_count = 0
+         logger.info(f"[{stream_id}] Connected to {rtmp_url}")
+ 
+         while True:
+             with streams_lock:
+                 if stream_id not in active_streams or not active_streams[stream_id]["running"]:
+                     break
+ 
+             ret, frame = cap.read()
+             if not ret:
+                 logger.warning(f"[{stream_id}] RTMP connection lost, reconnecting...")
+                 cap.release()
+                 time.sleep(2)
+                 cap = cv2.VideoCapture(rtmp_url)
+                 # Point the registry at the new capture so cleanup releases
+                 # the right handle after a reconnect
+                 with streams_lock:
+                     if stream_id in active_streams:
+                         active_streams[stream_id]["cap"] = cap
+                 continue
+ 
+             monitor.add_frame(frame)
+             frame_count += 1
+ 
+             if frame_count % 100 == 0:
+                 stats = monitor.get_stats()
+                 logger.info(f"[{stream_id}] Frames: {frame_count}, FPS: {stats['avg_fps']:.1f}")
+ 
+     except Exception as e:
+         logger.error(f"[{stream_id}] Stream error: {e}")
+ 
+     finally:
+         with streams_lock:
+             if stream_id in active_streams:
+                 if active_streams[stream_id]["cap"]:
+                     active_streams[stream_id]["cap"].release()
+                 if active_streams[stream_id]["monitor"]:
+                     active_streams[stream_id]["monitor"].stop_processing()
+                 active_streams[stream_id]["connected"] = False
+ 
+         logger.info(f"[{stream_id}] Stream closed")
+ 
+ 
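+ # Design note: each stream runs in its own daemon thread. Frames flow roughly
+ # RTMP -> cv2.VideoCapture -> monitor.add_frame() -> processing thread ->
+ # CircularFrameBuffer, from which /stream/video/{stream_id} serves MJPEG.
+ 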
+ # ==================== ENDPOINTS ====================
+ 
+ @app.post("/stream/start")
+ async def start_stream(request: StreamStartRequest):
+     """Start a new live stream."""
+     stream_id = f"stream_{int(time.time() * 1000)}"
+ 
+     try:
+         # Resolve model files from the camera path
+         gmm_path, mask_path = _get_model_files(request.camera_path)
+ 
+         # Create stream entry (rtmp_url is kept so /stream/restart can respawn the worker)
+         with streams_lock:
+             active_streams[stream_id] = {
+                 "running": True,
+                 "connected": False,
+                 "monitor": None,
+                 "cap": None,
+                 "thread": None,
+                 "camera_path": request.camera_path,
+                 "rtmp_url": request.rtmp_input_url
+             }
+ 
+         # Start background worker thread
+         thread = Thread(
+             target=_stream_worker,
+             args=(stream_id, request.rtmp_input_url, gmm_path, mask_path),
+             daemon=True
+         )
+         thread.start()
+ 
+         with streams_lock:
+             active_streams[stream_id]["thread"] = thread
+ 
+         logger.info(f"Stream {stream_id} started")
+         return {
+             "stream_id": stream_id,
+             "status": "starting",
+             "message": f"Stream {stream_id} is starting, will connect to {request.rtmp_input_url}"
+         }
+ 
+     except Exception as e:
+         logger.error(f"Failed to start stream: {e}")
+         raise HTTPException(status_code=400, detail=str(e))
+ 
+ 
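+ # Usage sketch (not executed here; assumes the service runs on localhost:8000
+ # and a camera named "kitchen" was trained via /camera/train_gmm):
+ #
+ #   import requests
+ #   resp = requests.post("http://localhost:8000/stream/start", json={
+ #       "rtmp_input_url": "rtmp://192.168.1.100:1935/live/kitchen",
+ #       "camera_path": "models/kitchen",
+ #   })
+ #   stream_id = resp.json()["stream_id"]
+ 
+ 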
+ @app.post("/stream/stop")
659
+ async def stop_stream(request: StreamStopRequest):
660
+ """Stop a live stream."""
661
+ stream_id = request.stream_id
662
+
663
+ with streams_lock:
664
+ if stream_id not in active_streams:
665
+ raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")
666
+
667
+ active_streams[stream_id]["running"] = False
668
+
669
+ logger.info(f"Stream {stream_id} stop requested")
670
+ return {"stream_id": stream_id, "status": "stopping"}
671
+
672
+
673
+ @app.get("/stream/status/{stream_id}")
674
+ async def get_stream_status(stream_id: str):
675
+ """Get stream status."""
676
+ with streams_lock:
677
+ if stream_id not in active_streams:
678
+ raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")
679
+
680
+ stream_data = active_streams[stream_id]
681
+ monitor = stream_data["monitor"]
682
+
683
+ stats = monitor.get_stats() if monitor else {}
684
+
685
+ return {
686
+ "stream_id": stream_id,
687
+ "connected": stream_data["connected"],
688
+ "running": stream_data["running"],
689
+ "camera_path": stream_data["camera_path"],
690
+ "fps": stats.get("avg_fps", 0),
691
+ "buffered_frames": stats.get("buffered_frames", 0),
692
+ "queue_size": stats.get("queue_size", 0)
693
+ }
694
+
695
+
696
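+ # Usage sketch (hypothetical host/port):
+ #
+ #   import requests
+ #   status = requests.get(f"http://localhost:8000/stream/status/{stream_id}").json()
+ #   print(status["fps"], status["buffered_frames"])
+ 
+ 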
+ @app.get("/stream/video/{stream_id}")
697
+ async def stream_video(stream_id: str):
698
+ """Stream video frames via MJPEG."""
699
+ with streams_lock:
700
+ if stream_id not in active_streams:
701
+ raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")
702
+
703
+ monitor = active_streams[stream_id]["monitor"]
704
+
705
+ if not monitor:
706
+ raise HTTPException(status_code=503, detail="Monitor not ready")
707
+
708
+ async def frame_generator():
709
+ while True:
710
+ with streams_lock:
711
+ if stream_id not in active_streams or not active_streams[stream_id]["running"]:
712
+ break
713
+
714
+ frame = monitor.get_latest_frame()
715
+ if frame is not None:
716
+ _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
717
+ yield (b'--frame\r\n'
718
+ b'Content-Type: image/jpeg\r\n'
719
+ b'Content-Length: ' + str(len(buffer)).encode() + b'\r\n\r\n'
720
+ + buffer.tobytes() + b'\r\n')
721
+ else:
722
+ await asyncio.sleep(0.01)
723
+
724
+ return StreamingResponse(
725
+ frame_generator(),
726
+ media_type="multipart/x-mixed-replace; boundary=frame"
727
+ )
728
+
729
+
730
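+ # The MJPEG response can be viewed directly in a browser or embedded in an
+ # HTML <img> tag, e.g. (hypothetical host and stream id):
+ #   <img src="http://localhost:8000/stream/video/stream_1700000000000">
+ 
+ 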
+ @app.get("/streams")
731
+ async def list_streams():
732
+ """List all active streams."""
733
+ with streams_lock:
734
+ streams_list = []
735
+ for stream_id, data in active_streams.items():
736
+ monitor = data["monitor"]
737
+ stats = monitor.get_stats() if monitor else {}
738
+
739
+ streams_list.append({
740
+ "stream_id": stream_id,
741
+ "connected": data["connected"],
742
+ "running": data["running"],
743
+ "camera_path": data["camera_path"],
744
+ "fps": stats.get("avg_fps", 0),
745
+ "buffered_frames": stats.get("buffered_frames", 0)
746
+ })
747
+
748
+ return {"total_streams": len(streams_list), "streams": streams_list}
749
+
750
+
751
+ @app.post("/stream/restart/{stream_id}")
752
+ async def restart_stream(stream_id: str):
753
+ """Restart a stream."""
754
+ with streams_lock:
755
+ if stream_id not in active_streams:
756
+ raise HTTPException(status_code=404, detail=f"Stream {stream_id} not found")
757
+
758
+ active_streams[stream_id]["running"] = False
759
+
760
+ await asyncio.sleep(2)
761
+
762
+ with streams_lock:
763
+ data = active_streams[stream_id]
764
+ data["running"] = True
765
+
766
+ return {"stream_id": stream_id, "status": "restarting"}
767
+
768
+ @app.post("/camera/extract_frame")
769
+ async def extract_frame_from_rtmp(request: dict):
770
+ """
771
+ Extract a single frame from RTMP stream for corner selection.
772
+
773
+ Request body:
774
+ {
775
+ "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
776
+ "camera_name": "kitchen"
777
+ }
778
+
779
+ Returns:
780
+ {
781
+ "success": true,
782
+ "frame_base64": "base64_encoded_image",
783
+ "frame_dimensions": {"width": 1920, "height": 1080}
784
+ }
785
+ """
786
+ try:
787
+ rtmp_url = request.get("rtmp_url")
788
+ camera_name = request.get("camera_name")
789
+
790
+ if not rtmp_url or not camera_name:
791
+ raise HTTPException(status_code=400, detail="Missing rtmp_url or camera_name")
792
+
793
+ # Connect to RTMP stream
794
+ cap = cv2.VideoCapture(rtmp_url)
795
+ if not cap.isOpened():
796
+ raise HTTPException(status_code=400, detail=f"Failed to connect to RTMP: {rtmp_url}")
797
+
798
+ # Read first frame
799
+ ret, frame = cap.read()
800
+ cap.release()
801
+
802
+ if not ret:
803
+ raise HTTPException(status_code=400, detail="Failed to read frame from RTMP stream")
804
+ import base64
805
+ # Convert frame to base64 for frontend display
806
+ _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
807
+ frame_base64 = base64.b64encode(buffer).decode('utf-8')
808
+
809
+ # Store frame temporarily for training (optional - could store in memory cache)
810
+ temp_dir = "temp_frames"
811
+ os.makedirs(temp_dir, exist_ok=True)
812
+ temp_frame_path = os.path.join(temp_dir, f"{camera_name}_reference.jpg")
813
+ cv2.imwrite(temp_frame_path, frame)
814
+
815
+ return {
816
+ "success": True,
817
+ "frame_base64": frame_base64,
818
+ "frame_dimensions": {
819
+ "width": frame.shape[1],
820
+ "height": frame.shape[0]
821
+ },
822
+ "temp_frame_path": temp_frame_path
823
+ }
824
+
825
+ except Exception as e:
826
+ logger.error(f"Extract frame error: {e}")
827
+ raise HTTPException(status_code=500, detail=str(e))
828
+
829
+
830
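+ # Usage sketch (hypothetical host; decodes the returned frame back to JPEG bytes):
+ #
+ #   import base64, requests
+ #   r = requests.post("http://localhost:8000/camera/extract_frame", json={
+ #       "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
+ #       "camera_name": "kitchen",
+ #   })
+ #   jpg_bytes = base64.b64decode(r.json()["frame_base64"])
+ 
+ 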
+ @app.post("/camera/train_gmm")
831
+ async def train_gmm_from_rtmp(request: dict):
832
+ """
833
+ Train GMM model from RTMP stream using N corner points (minimum 4).
834
+
835
+ Request body:
836
+ {
837
+ "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
838
+ "camera_name": "kitchen",
839
+ "corner_points": [
840
+ {"x": 100, "y": 50},
841
+ {"x": 400, "y": 45},
842
+ {"x": 700, "y": 55},
843
+ {"x": 800, "y": 60},
844
+ {"x": 850, "y": 300},
845
+ {"x": 850, "y": 600},
846
+ {"x": 400, "y": 620},
847
+ {"x": 50, "y": 580},
848
+ {"x": 45, "y": 300}
849
+ ], // Can be 4+ points for curved tables
850
+ "max_frames": 250,
851
+ "use_perspective_warp": false // NEW: Set false for non-rectangular tables
852
+ }
853
+ """
854
+ try:
855
+ rtmp_url = request.get("rtmp_url")
856
+ camera_name = request.get("camera_name")
857
+ corner_points = request.get("corner_points")
858
+ max_frames = request.get("max_frames", 250)
859
+ use_perspective_warp = request.get("use_perspective_warp", False) # NEW
860
+
861
+ # Validation
862
+ if not rtmp_url or not camera_name or not corner_points:
863
+ raise HTTPException(status_code=400, detail="Missing required parameters")
864
+
865
+ if len(corner_points) < 4:
866
+ raise HTTPException(status_code=400, detail="Minimum 4 corner points required")
867
+
868
+ logger.info(f"Starting GMM training for camera: {camera_name} with {len(corner_points)} points")
869
+
870
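+         # Note: when use_perspective_warp is true, the first four points are
+         # interpreted in order as top-left, top-right, bottom-right, bottom-left
+         # (they are mapped onto [0,0], [w,0], [w,h], [0,h] in Strategy A below).
+ 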
+         # ===== STEP 1: Connect to RTMP and capture frames =====
+         cap = cv2.VideoCapture(rtmp_url)
+         if not cap.isOpened():
+             raise HTTPException(status_code=400, detail=f"Failed to connect to RTMP: {rtmp_url}")
+ 
+         ret, first_frame = cap.read()
+         if not ret:
+             cap.release()
+             raise HTTPException(status_code=400, detail="Failed to read from RTMP stream")
+ 
+         h, w = first_frame.shape[:2]
+ 
+         # ===== STEP 2: Create polygon mask from N points =====
+         pts_polygon = np.array([
+             [point['x'], point['y']] for point in corner_points
+         ], dtype=np.int32)
+ 
+         # Binary mask of the table area: 255 inside the polygon, 0 outside
+         table_mask = np.zeros((h, w), dtype=np.uint8)
+         cv2.fillPoly(table_mask, [pts_polygon], 255)
+ 
+         # ===== STEP 3: Decide transformation strategy =====
+         import tempfile
+         temp_dir = tempfile.mkdtemp()
+         frame_count = 0
+ 
+         if use_perspective_warp and len(corner_points) == 4:
+             # ===== STRATEGY A: Perspective warp (rectangular tables only) =====
+             logger.info("Using perspective warp for rectangular table")
+ 
+             pts_src = np.array([
+                 [corner_points[0]['x'], corner_points[0]['y']],
+                 [corner_points[1]['x'], corner_points[1]['y']],
+                 [corner_points[2]['x'], corner_points[2]['y']],
+                 [corner_points[3]['x'], corner_points[3]['y']]
+             ], dtype=np.float32)
+ 
+             pts_dst = np.array([
+                 [0, 0], [w, 0], [w, h], [0, h]
+             ], dtype=np.float32)
+ 
+             matrix = cv2.getPerspectiveTransform(pts_src, pts_dst)
+ 
+             # Capture and warp frames (CAP_PROP_POS_FRAMES is a no-op on live
+             # streams; it only rewinds file-based sources)
+             cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+             while frame_count < max_frames:
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+ 
+                 warped = cv2.warpPerspective(frame, matrix, (w, h))
+                 frame_path = os.path.join(temp_dir, f'b{frame_count:05d}.png')
+                 cv2.imwrite(frame_path, warped)
+                 frame_count += 1
+ 
+                 if frame_count % 50 == 0:
+                     logger.info(f"Captured {frame_count}/{max_frames} frames")
+ 
+             # For warped images the mask is the full frame (already aligned)
+             final_mask = np.ones((h, w), dtype=np.uint8) * 255
+ 
+         else:
+             # ===== STRATEGY B: Direct masking (curved/complex tables) =====
+             logger.info(f"Using direct masking for {len(corner_points)}-point polygon (curved table)")
+ 
+             # Capture frames WITHOUT warping; the mask is applied again at inference
+             cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+             while frame_count < max_frames:
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+ 
+                 # Apply mask to frame (zero out pixels outside the table area)
+                 masked_frame = cv2.bitwise_and(frame, frame, mask=table_mask)
+ 
+                 frame_path = os.path.join(temp_dir, f'b{frame_count:05d}.png')
+                 cv2.imwrite(frame_path, masked_frame)
+                 frame_count += 1
+ 
+                 if frame_count % 50 == 0:
+                     logger.info(f"Captured {frame_count}/{max_frames} frames")
+ 
+             # Use the original polygon mask
+             final_mask = table_mask
+ 
+         cap.release()
+ 
+         if frame_count == 0:
+             raise HTTPException(status_code=400, detail="No frames captured")
+ 
+         logger.info(f"Captured {frame_count} frames, starting GMM training...")
+ 
+         # ===== STEP 4: Train GMM =====
+         # The local GMM module fits a per-pixel background mixture over the
+         # captured frames; K is the number of mixture components and alpha the
+         # learning rate (naming follows standard background-subtraction GMMs)
+         from GMM import GMM
+         gmm = GMM(temp_dir, frame_count, alpha=0.05)
+         gmm.train(K=4)
+         logger.info("GMM training complete")
+ 
+         # ===== STEP 5: Save artifacts =====
+         camera_path = os.path.join("models", camera_name)
+         os.makedirs(camera_path, exist_ok=True)
+ 
+         # 1. Save GMM model
+         gmm_path = os.path.join(camera_path, "gmm_model.joblib")
+         gmm.save_model(gmm_path)
+ 
+         # 2. Save mask (polygon-based, not rectangular)
+         mask_path = os.path.join(camera_path, "mask.png")
+         cv2.imwrite(mask_path, final_mask)
+         logger.info(f"Saved {len(corner_points)}-point polygon mask to {mask_path}")
+ 
+         # 3. Create thumbnail with polygon overlay
+         thumb_frame = first_frame.copy()
+ 
+         # Draw filled polygon with transparency
+         overlay = thumb_frame.copy()
+         cv2.fillPoly(overlay, [pts_polygon], (0, 255, 0))
+         cv2.addWeighted(thumb_frame, 0.7, overlay, 0.3, 0, thumb_frame)
+ 
+         # Draw polygon border
+         cv2.polylines(thumb_frame, [pts_polygon], True, (0, 255, 0), 3)
+ 
+         # Draw numbered corner points
+         colors = [
+             (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
+             (255, 0, 255), (0, 255, 255), (128, 0, 128), (255, 128, 0)
+         ]
+ 
+         for i, point in enumerate(corner_points):
+             x, y = point['x'], point['y']
+             color = colors[i % len(colors)]
+ 
+             cv2.circle(thumb_frame, (x, y), 8, color, -1)
+             cv2.circle(thumb_frame, (x, y), 10, (255, 255, 255), 2)
+ 
+             # Point number
+             cv2.putText(thumb_frame, str(i + 1), (x + 15, y),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
+ 
+         # Camera name label
+         cv2.putText(thumb_frame, camera_name, (30, 50),
+                     cv2.FONT_HERSHEY_DUPLEX, 1.5, (255, 255, 255), 3)
+         cv2.putText(thumb_frame, camera_name, (30, 50),
+                     cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 255, 0), 2)
+ 
+         # Point count indicator
+         cv2.putText(thumb_frame, f"{len(corner_points)} points", (30, 90),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
+ 
+         thumb_path = os.path.join(camera_path, "thumb.png")
+         cv2.imwrite(thumb_path, thumb_frame)
+ 
+         # 4. Save polygon metadata (for later reconstruction of the region)
+         metadata = {
+             "camera_name": camera_name,
+             "num_points": len(corner_points),
+             "corner_points": corner_points,
+             "frame_dimensions": {"width": w, "height": h},
+             "use_perspective_warp": use_perspective_warp,
+             "training_date": datetime.now().isoformat()
+         }
+ 
+         import json
+         metadata_path = os.path.join(camera_path, "metadata.json")
+         with open(metadata_path, 'w') as f:
+             json.dump(metadata, f, indent=2)
+ 
+         logger.info(f"Saved metadata to {metadata_path}")
+ 
+         # Clean up the temporary training frames
+         import shutil
+         shutil.rmtree(temp_dir)
+ 
+         logger.info(f"✅ Camera '{camera_name}' training complete with {len(corner_points)}-point polygon!")
+ 
+         return {
+             "success": True,
+             "camera_name": camera_name,
+             "camera_path": camera_path,
+             "frames_captured": frame_count,
+             "polygon_points": len(corner_points),
+             "use_perspective_warp": use_perspective_warp,
+             "model_files": {
+                 "gmm_model": gmm_path,
+                 "mask": mask_path,
+                 "thumbnail": thumb_path,
+                 "metadata": metadata_path
+             }
+         }
+ 
+     except HTTPException:
+         # Preserve deliberate 4xx responses instead of converting them to 500s
+         raise
+     except Exception as e:
+         logger.error(f"GMM training error: {e}")
+         import traceback
+         logger.error(traceback.format_exc())
+         raise HTTPException(status_code=500, detail=str(e))
+ 
+ 
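+ # Usage sketch (hypothetical host; corner points would normally come from the
+ # frontend after /camera/extract_frame):
+ #
+ #   import requests
+ #   requests.post("http://localhost:8000/camera/train_gmm", json={
+ #       "rtmp_url": "rtmp://192.168.1.100:1935/live/kitchen",
+ #       "camera_name": "kitchen",
+ #       "corner_points": [{"x": 100, "y": 50}, {"x": 850, "y": 60},
+ #                         {"x": 850, "y": 600}, {"x": 45, "y": 580}],
+ #       "max_frames": 250,
+ #       "use_perspective_warp": False,
+ #   })
+ 
+ 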
+ @app.get("/cameras")
1068
+ async def list_cameras():
1069
+ """
1070
+ List all trained cameras with their metadata.
1071
+
1072
+ Returns:
1073
+ {
1074
+ "cameras": [
1075
+ {
1076
+ "name": "kitchen",
1077
+ "path": "models/kitchen",
1078
+ "thumbnail": "models/kitchen/thumb.png",
1079
+ "has_gmm_model": true,
1080
+ "has_mask": true
1081
+ }
1082
+ ]
1083
+ }
1084
+ """
1085
+ try:
1086
+ cameras = []
1087
+ models_dir = "models"
1088
+
1089
+ if not os.path.exists(models_dir):
1090
+ return {"cameras": []}
1091
+
1092
+ for camera_name in os.listdir(models_dir):
1093
+ camera_path = os.path.join(models_dir, camera_name)
1094
+
1095
+ if not os.path.isdir(camera_path):
1096
+ continue
1097
+
1098
+ gmm_path = os.path.join(camera_path, "gmm_model.joblib")
1099
+ mask_path = os.path.join(camera_path, "mask.png")
1100
+ thumb_path = os.path.join(camera_path, "thumb.png")
1101
+
1102
+ cameras.append({
1103
+ "name": camera_name,
1104
+ "path": camera_path,
1105
+ "thumbnail": thumb_path if os.path.exists(thumb_path) else None,
1106
+ "has_gmm_model": os.path.exists(gmm_path),
1107
+ "has_mask": os.path.exists(mask_path)
1108
+ })
1109
+
1110
+ return {"cameras": cameras}
1111
+
1112
+ except Exception as e:
1113
+ logger.error(f"List cameras error: {e}")
1114
+ raise HTTPException(status_code=500, detail=str(e))
1115
+
1116
+
1117
+ @app.delete("/camera/{camera_name}")
1118
+ async def delete_camera(camera_name: str):
1119
+ """
1120
+ Delete a trained camera and all its files.
1121
+ """
1122
+ try:
1123
+ camera_path = os.path.join("models", camera_name)
1124
+
1125
+ if not os.path.exists(camera_path):
1126
+ raise HTTPException(status_code=404, detail=f"Camera '{camera_name}' not found")
1127
+
1128
+ import shutil
1129
+ shutil.rmtree(camera_path)
1130
+
1131
+ logger.info(f"Deleted camera: {camera_name}")
1132
+
1133
+ return {
1134
+ "success": True,
1135
+ "message": f"Camera '{camera_name}' deleted successfully"
1136
+ }
1137
+
1138
+ except Exception as e:
1139
+ logger.error(f"Delete camera error: {e}")
1140
+ raise HTTPException(status_code=500, detail=str(e))
1141
+
1142
+
1143
+ @app.get("/health")
1144
+ async def health_check():
1145
+ """Health check endpoint."""
1146
+ with streams_lock:
1147
+ stream_count = len(active_streams)
1148
+
1149
+ return {
1150
+ "status": "healthy",
1151
+ "active_streams": stream_count,
1152
+ "timestamp": datetime.now().isoformat()
1153
+ }
1154
+
1155
+
1156
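+ # Run directly with `python app.py`; uvicorn serves the API on port 8000.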
+ if __name__ == "__main__":
      uvicorn.run(app, host="0.0.0.0", port=8000)