Yingtao-Zheng commited on
Commit
357e9dc
·
1 Parent(s): dadceb2

Combined main.py and main_yolo.py to create main_combined.py, where main.py came from Mohamad and main_yolo.py came from UI team. (To be tested)

Browse files
Files changed (2) hide show
  1. main_combined.py +980 -0
  2. main_yolo.py +601 -0
main_combined.py ADDED
@@ -0,0 +1,980 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
2
+ from fastapi.staticfiles import StaticFiles
3
+ from fastapi.responses import FileResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ from typing import Optional, List, Any
7
+ import base64
8
+ import cv2
9
+ import numpy as np
10
+ import aiosqlite
11
+ import json
12
+ from datetime import datetime, timedelta
13
+ import math
14
+ import os
15
+ from pathlib import Path
16
+ from typing import Callable
17
+ import asyncio
18
+ import concurrent.futures
19
+ import threading
20
+
21
+ from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
22
+ from av import VideoFrame
23
+
24
+ from mediapipe.tasks.python.vision import FaceLandmarksConnections
25
+ from ui.pipeline import FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline
26
+ from models.face_mesh import FaceMeshDetector
27
+
28
+ # ================ FACE MESH DRAWING (server-side, for WebRTC) ================
29
+
30
+ _FONT = cv2.FONT_HERSHEY_SIMPLEX
31
+ _CYAN = (255, 255, 0)
32
+ _GREEN = (0, 255, 0)
33
+ _MAGENTA = (255, 0, 255)
34
+ _ORANGE = (0, 165, 255)
35
+ _RED = (0, 0, 255)
36
+ _WHITE = (255, 255, 255)
37
+ _LIGHT_GREEN = (144, 238, 144)
38
+
39
+ _TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
40
+ _CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
41
+ _LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
42
+ _RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
43
+ _NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
44
+ _LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
45
+ _LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
46
+ _LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
47
+ _RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
48
+
49
+
50
+ def _lm_px(lm, idx, w, h):
51
+ return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
52
+
53
+
54
+ def _draw_polyline(frame, lm, indices, w, h, color, thickness):
55
+ for i in range(len(indices) - 1):
56
+ cv2.line(frame, _lm_px(lm, indices[i], w, h), _lm_px(lm, indices[i + 1], w, h), color, thickness, cv2.LINE_AA)
57
+
58
+
59
+ def _draw_face_mesh(frame, lm, w, h):
60
+ """Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines."""
61
+ # Tessellation (gray triangular grid, semi-transparent)
62
+ overlay = frame.copy()
63
+ for s, e in _TESSELATION_CONNS:
64
+ cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
65
+ cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
66
+ # Contours
67
+ for s, e in _CONTOUR_CONNS:
68
+ cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
69
+ # Eyebrows
70
+ _draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
71
+ _draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
72
+ # Nose
73
+ _draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
74
+ # Lips
75
+ _draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
76
+ _draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
77
+ # Eyes
78
+ left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
79
+ cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
80
+ right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
81
+ cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
82
+ # EAR key points
83
+ for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
84
+ for idx in indices:
85
+ cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
86
+ # Irises + gaze lines
87
+ for iris_idx, eye_inner, eye_outer in [
88
+ (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
89
+ (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
90
+ ]:
91
+ iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
92
+ center = iris_pts[0]
93
+ if len(iris_pts) >= 5:
94
+ radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
95
+ radius = max(int(np.mean(radii)), 2)
96
+ cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
97
+ cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
98
+ eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
99
+ eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
100
+ dx, dy = center[0] - eye_cx, center[1] - eye_cy
101
+ cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
102
+
103
+
104
+ def _draw_hud(frame, result, model_name):
105
+ """Draw status bar and detail overlay matching live_demo.py."""
106
+ h, w = frame.shape[:2]
107
+ is_focused = result["is_focused"]
108
+ status = "FOCUSED" if is_focused else "NOT FOCUSED"
109
+ color = _GREEN if is_focused else _RED
110
+
111
+ # Top bar
112
+ cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
113
+ cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
114
+ cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)
115
+
116
+ # Detail line
117
+ conf = result.get("mlp_prob", result.get("raw_score", 0.0))
118
+ mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
119
+ sf = result.get("s_face", 0)
120
+ se = result.get("s_eye", 0)
121
+ detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
122
+ cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)
123
+
124
+ # Head pose (top right)
125
+ if result.get("yaw") is not None:
126
+ cv2.putText(frame, f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
127
+ (w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
128
+
129
+ # Yawn indicator
130
+ if result.get("is_yawning"):
131
+ cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
132
+
133
+ # Landmark indices used for face mesh drawing on client (union of all groups).
134
+ # Sending only these instead of all 478 saves ~60% of the landmarks payload.
135
+ _MESH_INDICES = sorted(set(
136
+ [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] # face oval
137
+ + [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246] # left eye
138
+ + [362,382,381,380,374,373,390,249,263,466,388,387,386,385,384,398] # right eye
139
+ + [468,469,470,471,472, 473,474,475,476,477] # irises
140
+ + [70,63,105,66,107,55,65,52,53,46] # left eyebrow
141
+ + [300,293,334,296,336,285,295,282,283,276] # right eyebrow
142
+ + [6,197,195,5,4,1,19,94,2] # nose bridge
143
+ + [61,146,91,181,84,17,314,405,321,375,291,409,270,269,267,0,37,39,40,185] # lips outer
144
+ + [78,95,88,178,87,14,317,402,318,324,308,415,310,311,312,13,82,81,80,191] # lips inner
145
+ + [33,160,158,133,153,145] # left EAR key points
146
+ + [362,385,387,263,373,380] # right EAR key points
147
+ ))
148
+ # Build a lookup: original_index -> position in sparse array, so client can reconstruct.
149
+ _MESH_INDEX_SET = set(_MESH_INDICES)
150
+
151
+ # Initialize FastAPI app
152
+ app = FastAPI(title="Focus Guard API")
153
+
154
+ # Add CORS middleware
155
+ app.add_middleware(
156
+ CORSMiddleware,
157
+ allow_origins=["*"],
158
+ allow_credentials=True,
159
+ allow_methods=["*"],
160
+ allow_headers=["*"],
161
+ )
162
+
163
+ # Global variables
164
+ db_path = "focus_guard.db"
165
+ pcs = set()
166
+ _cached_model_name = "mlp" # in-memory cache, updated via /api/settings
167
+
168
+ async def _wait_for_ice_gathering(pc: RTCPeerConnection):
169
+ if pc.iceGatheringState == "complete":
170
+ return
171
+ done = asyncio.Event()
172
+
173
+ @pc.on("icegatheringstatechange")
174
+ def _on_state_change():
175
+ if pc.iceGatheringState == "complete":
176
+ done.set()
177
+
178
+ await done.wait()
179
+
180
+ # ================ DATABASE MODELS ================
181
+
182
+ async def init_database():
183
+ """Initialize SQLite database with required tables"""
184
+ async with aiosqlite.connect(db_path) as db:
185
+ # FocusSessions table
186
+ await db.execute("""
187
+ CREATE TABLE IF NOT EXISTS focus_sessions (
188
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
189
+ start_time TIMESTAMP NOT NULL,
190
+ end_time TIMESTAMP,
191
+ duration_seconds INTEGER DEFAULT 0,
192
+ focus_score REAL DEFAULT 0.0,
193
+ total_frames INTEGER DEFAULT 0,
194
+ focused_frames INTEGER DEFAULT 0,
195
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
196
+ )
197
+ """)
198
+
199
+ # FocusEvents table
200
+ await db.execute("""
201
+ CREATE TABLE IF NOT EXISTS focus_events (
202
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
203
+ session_id INTEGER NOT NULL,
204
+ timestamp TIMESTAMP NOT NULL,
205
+ is_focused BOOLEAN NOT NULL,
206
+ confidence REAL NOT NULL,
207
+ detection_data TEXT,
208
+ FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
209
+ )
210
+ """)
211
+
212
+ # UserSettings table
213
+ await db.execute("""
214
+ CREATE TABLE IF NOT EXISTS user_settings (
215
+ id INTEGER PRIMARY KEY CHECK (id = 1),
216
+ sensitivity INTEGER DEFAULT 6,
217
+ notification_enabled BOOLEAN DEFAULT 1,
218
+ notification_threshold INTEGER DEFAULT 30,
219
+ frame_rate INTEGER DEFAULT 30,
220
+ model_name TEXT DEFAULT 'mlp'
221
+ )
222
+ """)
223
+
224
+ # Insert default settings if not exists
225
+ await db.execute("""
226
+ INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
227
+ VALUES (1, 6, 1, 30, 30, 'mlp')
228
+ """)
229
+
230
+ await db.commit()
231
+
232
+ # ================ PYDANTIC MODELS ================
233
+
234
+ class SessionCreate(BaseModel):
235
+ pass
236
+
237
+ class SessionEnd(BaseModel):
238
+ session_id: int
239
+
240
+ class SettingsUpdate(BaseModel):
241
+ sensitivity: Optional[int] = None
242
+ notification_enabled: Optional[bool] = None
243
+ notification_threshold: Optional[int] = None
244
+ frame_rate: Optional[int] = None
245
+ model_name: Optional[str] = None
246
+
247
+ class VideoTransformTrack(VideoStreamTrack):
248
+ def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
249
+ super().__init__()
250
+ self.track = track
251
+ self.session_id = session_id
252
+ self.get_channel = get_channel
253
+ self.last_inference_time = 0
254
+ self.min_inference_interval = 1 / 60
255
+ self.last_frame = None
256
+
257
+ async def recv(self):
258
+ frame = await self.track.recv()
259
+ img = frame.to_ndarray(format="bgr24")
260
+ if img is None:
261
+ return frame
262
+
263
+ # Normalize size for inference/drawing
264
+ img = cv2.resize(img, (640, 480))
265
+
266
+ now = datetime.now().timestamp()
267
+ do_infer = (now - self.last_inference_time) >= self.min_inference_interval
268
+
269
+ if do_infer:
270
+ self.last_inference_time = now
271
+
272
+ model_name = _cached_model_name
273
+ if model_name not in pipelines or pipelines.get(model_name) is None:
274
+ model_name = 'mlp'
275
+ active_pipeline = pipelines.get(model_name)
276
+
277
+ if active_pipeline is not None:
278
+ loop = asyncio.get_event_loop()
279
+ out = await loop.run_in_executor(
280
+ _inference_executor,
281
+ _process_frame_safe,
282
+ active_pipeline,
283
+ img,
284
+ model_name,
285
+ )
286
+ is_focused = out["is_focused"]
287
+ confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
288
+ metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}
289
+
290
+ # Draw face mesh + HUD on the video frame
291
+ h_f, w_f = img.shape[:2]
292
+ lm = out.get("landmarks")
293
+ if lm is not None:
294
+ _draw_face_mesh(img, lm, w_f, h_f)
295
+ _draw_hud(img, out, model_name)
296
+ else:
297
+ is_focused = False
298
+ confidence = 0.0
299
+ metadata = {"model": model_name}
300
+ cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
301
+ cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)
302
+
303
+ if self.session_id:
304
+ await store_focus_event(self.session_id, is_focused, confidence, metadata)
305
+
306
+ channel = self.get_channel()
307
+ if channel and channel.readyState == "open":
308
+ try:
309
+ channel.send(json.dumps({
310
+ "type": "detection",
311
+ "focused": is_focused,
312
+ "confidence": round(confidence, 3),
313
+ "detections": [],
314
+ "model": model_name,
315
+ }))
316
+ except Exception:
317
+ pass
318
+
319
+ self.last_frame = img
320
+ elif self.last_frame is not None:
321
+ img = self.last_frame
322
+
323
+ new_frame = VideoFrame.from_ndarray(img, format="bgr24")
324
+ new_frame.pts = frame.pts
325
+ new_frame.time_base = frame.time_base
326
+ return new_frame
327
+
328
+ # ================ DATABASE OPERATIONS ================
329
+
330
+ async def create_session():
331
+ async with aiosqlite.connect(db_path) as db:
332
+ cursor = await db.execute(
333
+ "INSERT INTO focus_sessions (start_time) VALUES (?)",
334
+ (datetime.now().isoformat(),)
335
+ )
336
+ await db.commit()
337
+ return cursor.lastrowid
338
+
339
+ async def end_session(session_id: int):
340
+ async with aiosqlite.connect(db_path) as db:
341
+ cursor = await db.execute(
342
+ "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
343
+ (session_id,)
344
+ )
345
+ row = await cursor.fetchone()
346
+
347
+ if not row:
348
+ return None
349
+
350
+ start_time_str, total_frames, focused_frames = row
351
+ start_time = datetime.fromisoformat(start_time_str)
352
+ end_time = datetime.now()
353
+ duration = (end_time - start_time).total_seconds()
354
+ focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
355
+
356
+ await db.execute("""
357
+ UPDATE focus_sessions
358
+ SET end_time = ?, duration_seconds = ?, focus_score = ?
359
+ WHERE id = ?
360
+ """, (end_time.isoformat(), int(duration), focus_score, session_id))
361
+
362
+ await db.commit()
363
+
364
+ return {
365
+ 'session_id': session_id,
366
+ 'start_time': start_time_str,
367
+ 'end_time': end_time.isoformat(),
368
+ 'duration_seconds': int(duration),
369
+ 'focus_score': round(focus_score, 3),
370
+ 'total_frames': total_frames,
371
+ 'focused_frames': focused_frames
372
+ }
373
+
374
+ async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
375
+ async with aiosqlite.connect(db_path) as db:
376
+ await db.execute("""
377
+ INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
378
+ VALUES (?, ?, ?, ?, ?)
379
+ """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
380
+
381
+ await db.execute("""
382
+ UPDATE focus_sessions
383
+ SET total_frames = total_frames + 1,
384
+ focused_frames = focused_frames + ?
385
+ WHERE id = ?
386
+ """, (1 if is_focused else 0, session_id))
387
+ await db.commit()
388
+
389
+
390
+ class _EventBuffer:
391
+ """Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes."""
392
+
393
+ def __init__(self, flush_interval: float = 2.0):
394
+ self._buf: list = []
395
+ self._lock = asyncio.Lock()
396
+ self._flush_interval = flush_interval
397
+ self._task: asyncio.Task | None = None
398
+ self._total_frames = 0
399
+ self._focused_frames = 0
400
+
401
+ def start(self):
402
+ if self._task is None:
403
+ self._task = asyncio.create_task(self._flush_loop())
404
+
405
+ async def stop(self):
406
+ if self._task:
407
+ self._task.cancel()
408
+ try:
409
+ await self._task
410
+ except asyncio.CancelledError:
411
+ pass
412
+ self._task = None
413
+ await self._flush()
414
+
415
+ def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict):
416
+ self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
417
+ self._total_frames += 1
418
+ if is_focused:
419
+ self._focused_frames += 1
420
+
421
+ async def _flush_loop(self):
422
+ while True:
423
+ await asyncio.sleep(self._flush_interval)
424
+ await self._flush()
425
+
426
+ async def _flush(self):
427
+ async with self._lock:
428
+ if not self._buf:
429
+ return
430
+ batch = self._buf[:]
431
+ total = self._total_frames
432
+ focused = self._focused_frames
433
+ self._buf.clear()
434
+ self._total_frames = 0
435
+ self._focused_frames = 0
436
+
437
+ if not batch:
438
+ return
439
+
440
+ session_id = batch[0][0]
441
+ try:
442
+ async with aiosqlite.connect(db_path) as db:
443
+ await db.executemany("""
444
+ INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
445
+ VALUES (?, ?, ?, ?, ?)
446
+ """, batch)
447
+ await db.execute("""
448
+ UPDATE focus_sessions
449
+ SET total_frames = total_frames + ?,
450
+ focused_frames = focused_frames + ?
451
+ WHERE id = ?
452
+ """, (total, focused, session_id))
453
+ await db.commit()
454
+ except Exception as e:
455
+ print(f"[DB] Flush error: {e}")
456
+
457
+ # ================ STARTUP/SHUTDOWN ================
458
+
459
+ pipelines = {
460
+ "geometric": None,
461
+ "mlp": None,
462
+ "hybrid": None,
463
+ "xgboost": None,
464
+ }
465
+
466
+ # Thread pool for CPU-bound inference so the event loop stays responsive.
467
+ _inference_executor = concurrent.futures.ThreadPoolExecutor(
468
+ max_workers=4,
469
+ thread_name_prefix="inference",
470
+ )
471
+ # One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
472
+ # multiple frames are processed in parallel by the thread pool.
473
+ _pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
474
+
475
+
476
+ def _process_frame_safe(pipeline, frame, model_name: str):
477
+ """Run process_frame in executor with per-pipeline lock."""
478
+ with _pipeline_locks[model_name]:
479
+ return pipeline.process_frame(frame)
480
+
481
+ @app.on_event("startup")
482
+ async def startup_event():
483
+ global pipelines, _cached_model_name
484
+ print(" Starting Focus Guard API...")
485
+ await init_database()
486
+ # Load cached model name from DB
487
+ async with aiosqlite.connect(db_path) as db:
488
+ cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
489
+ row = await cursor.fetchone()
490
+ if row:
491
+ _cached_model_name = row[0]
492
+ print("[OK] Database initialized")
493
+
494
+ try:
495
+ pipelines["geometric"] = FaceMeshPipeline()
496
+ print("[OK] FaceMeshPipeline (geometric) loaded")
497
+ except Exception as e:
498
+ print(f"[WARN] FaceMeshPipeline unavailable: {e}")
499
+
500
+ try:
501
+ pipelines["mlp"] = MLPPipeline()
502
+ print("[OK] MLPPipeline loaded")
503
+ except Exception as e:
504
+ print(f"[ERR] Failed to load MLPPipeline: {e}")
505
+
506
+ try:
507
+ pipelines["hybrid"] = HybridFocusPipeline()
508
+ print("[OK] HybridFocusPipeline loaded")
509
+ except Exception as e:
510
+ print(f"[WARN] HybridFocusPipeline unavailable: {e}")
511
+
512
+ try:
513
+ pipelines["xgboost"] = XGBoostPipeline()
514
+ print("[OK] XGBoostPipeline loaded")
515
+ except Exception as e:
516
+ print(f"[ERR] Failed to load XGBoostPipeline: {e}")
517
+
518
+ @app.on_event("shutdown")
519
+ async def shutdown_event():
520
+ _inference_executor.shutdown(wait=False)
521
+ print(" Shutting down Focus Guard API...")
522
+
523
+ # ================ WEBRTC SIGNALING ================
524
+
525
+ @app.post("/api/webrtc/offer")
526
+ async def webrtc_offer(offer: dict):
527
+ try:
528
+ print(f"Received WebRTC offer")
529
+
530
+ pc = RTCPeerConnection()
531
+ pcs.add(pc)
532
+
533
+ session_id = await create_session()
534
+ print(f"Created session: {session_id}")
535
+
536
+ channel_ref = {"channel": None}
537
+
538
+ @pc.on("datachannel")
539
+ def on_datachannel(channel):
540
+ print(f"Data channel opened")
541
+ channel_ref["channel"] = channel
542
+
543
+ @pc.on("track")
544
+ def on_track(track):
545
+ print(f"Received track: {track.kind}")
546
+ if track.kind == "video":
547
+ local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
548
+ pc.addTrack(local_track)
549
+ print(f"Video track added")
550
+
551
+ @track.on("ended")
552
+ async def on_ended():
553
+ print(f"Track ended")
554
+
555
+ @pc.on("connectionstatechange")
556
+ async def on_connectionstatechange():
557
+ print(f"Connection state changed: {pc.connectionState}")
558
+ if pc.connectionState in ("failed", "closed", "disconnected"):
559
+ try:
560
+ await end_session(session_id)
561
+ except Exception as e:
562
+ print(f"⚠Error ending session: {e}")
563
+ pcs.discard(pc)
564
+ await pc.close()
565
+
566
+ await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
567
+ print(f"Remote description set")
568
+
569
+ answer = await pc.createAnswer()
570
+ await pc.setLocalDescription(answer)
571
+ print(f"Answer created")
572
+
573
+ await _wait_for_ice_gathering(pc)
574
+ print(f"ICE gathering complete")
575
+
576
+ return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
577
+
578
+ except Exception as e:
579
+ print(f"WebRTC offer error: {e}")
580
+ import traceback
581
+ traceback.print_exc()
582
+ raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
583
+
584
+ # ================ WEBSOCKET ================
585
+
586
+ @app.websocket("/ws/video")
587
+ async def websocket_endpoint(websocket: WebSocket):
588
+ await websocket.accept()
589
+ session_id = None
590
+ frame_count = 0
591
+ running = True
592
+ event_buffer = _EventBuffer(flush_interval=2.0)
593
+
594
+ # Latest frame slot — only the most recent frame is kept, older ones are dropped.
595
+ # Using a dict so nested functions can mutate without nonlocal issues.
596
+ _slot = {"frame": None}
597
+ _frame_ready = asyncio.Event()
598
+
599
+ async def _receive_loop():
600
+ """Receive messages as fast as possible. Binary = frame, text = control."""
601
+ nonlocal session_id, running
602
+ try:
603
+ while running:
604
+ msg = await websocket.receive()
605
+ msg_type = msg.get("type", "")
606
+
607
+ if msg_type == "websocket.disconnect":
608
+ running = False
609
+ _frame_ready.set()
610
+ return
611
+
612
+ # Binary message → JPEG frame (fast path, no base64)
613
+ raw_bytes = msg.get("bytes")
614
+ if raw_bytes is not None and len(raw_bytes) > 0:
615
+ _slot["frame"] = raw_bytes
616
+ _frame_ready.set()
617
+ continue
618
+
619
+ # Text message → JSON control command (or legacy base64 frame)
620
+ text = msg.get("text")
621
+ if not text:
622
+ continue
623
+ data = json.loads(text)
624
+
625
+ if data["type"] == "frame":
626
+ # Legacy base64 path (fallback)
627
+ _slot["frame"] = base64.b64decode(data["image"])
628
+ _frame_ready.set()
629
+
630
+ elif data["type"] == "start_session":
631
+ session_id = await create_session()
632
+ event_buffer.start()
633
+ for p in pipelines.values():
634
+ if p is not None and hasattr(p, "reset_session"):
635
+ p.reset_session()
636
+ await websocket.send_json({"type": "session_started", "session_id": session_id})
637
+
638
+ elif data["type"] == "end_session":
639
+ if session_id:
640
+ await event_buffer.stop()
641
+ summary = await end_session(session_id)
642
+ if summary:
643
+ await websocket.send_json({"type": "session_ended", "summary": summary})
644
+ session_id = None
645
+ except WebSocketDisconnect:
646
+ running = False
647
+ _frame_ready.set()
648
+ except Exception as e:
649
+ print(f"[WS] receive error: {e}")
650
+ running = False
651
+ _frame_ready.set()
652
+
653
+ async def _process_loop():
654
+ """Process only the latest frame, dropping stale ones."""
655
+ nonlocal frame_count, running
656
+ loop = asyncio.get_event_loop()
657
+ while running:
658
+ await _frame_ready.wait()
659
+ _frame_ready.clear()
660
+ if not running:
661
+ return
662
+
663
+ # Grab latest frame and clear slot
664
+ raw = _slot["frame"]
665
+ _slot["frame"] = None
666
+ if raw is None:
667
+ continue
668
+
669
+ try:
670
+ nparr = np.frombuffer(raw, np.uint8)
671
+ frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
672
+ if frame is None:
673
+ continue
674
+ frame = cv2.resize(frame, (640, 480))
675
+
676
+ model_name = _cached_model_name
677
+ if model_name not in pipelines or pipelines.get(model_name) is None:
678
+ model_name = "mlp"
679
+ active_pipeline = pipelines.get(model_name)
680
+
681
+ landmarks_list = None
682
+ if active_pipeline is not None:
683
+ out = await loop.run_in_executor(
684
+ _inference_executor,
685
+ _process_frame_safe,
686
+ active_pipeline,
687
+ frame,
688
+ model_name,
689
+ )
690
+ is_focused = out["is_focused"]
691
+ confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
692
+
693
+ lm = out.get("landmarks")
694
+ if lm is not None:
695
+ # Send all 478 landmarks as flat array for tessellation drawing
696
+ landmarks_list = [
697
+ [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
698
+ for i in range(lm.shape[0])
699
+ ]
700
+
701
+ if session_id:
702
+ event_buffer.add(session_id, is_focused, confidence, {
703
+ "s_face": out.get("s_face", 0.0),
704
+ "s_eye": out.get("s_eye", 0.0),
705
+ "mar": out.get("mar", 0.0),
706
+ "model": model_name,
707
+ })
708
+ else:
709
+ is_focused = False
710
+ confidence = 0.0
711
+
712
+ resp = {
713
+ "type": "detection",
714
+ "focused": is_focused,
715
+ "confidence": round(confidence, 3),
716
+ "detections": [],
717
+ "model": model_name,
718
+ "fc": frame_count,
719
+ "frame_count": frame_count,
720
+ }
721
+ if active_pipeline is not None:
722
+ # Send detailed metrics for HUD
723
+ if out.get("yaw") is not None:
724
+ resp["yaw"] = round(out["yaw"], 1)
725
+ resp["pitch"] = round(out["pitch"], 1)
726
+ resp["roll"] = round(out["roll"], 1)
727
+ if out.get("mar") is not None:
728
+ resp["mar"] = round(out["mar"], 3)
729
+ resp["sf"] = round(out.get("s_face", 0), 3)
730
+ resp["se"] = round(out.get("s_eye", 0), 3)
731
+ if landmarks_list is not None:
732
+ resp["lm"] = landmarks_list
733
+ await websocket.send_json(resp)
734
+ frame_count += 1
735
+ except Exception as e:
736
+ print(f"[WS] process error: {e}")
737
+
738
+ try:
739
+ await asyncio.gather(_receive_loop(), _process_loop())
740
+ except Exception:
741
+ pass
742
+ finally:
743
+ running = False
744
+ if session_id:
745
+ await event_buffer.stop()
746
+ await end_session(session_id)
747
+
748
+ # ================ API ENDPOINTS ================
749
+
750
+ @app.post("/api/sessions/start")
751
+ async def api_start_session():
752
+ session_id = await create_session()
753
+ return {"session_id": session_id}
754
+
755
+ @app.post("/api/sessions/end")
756
+ async def api_end_session(data: SessionEnd):
757
+ summary = await end_session(data.session_id)
758
+ if not summary: raise HTTPException(status_code=404, detail="Session not found")
759
+ return summary
760
+
761
+ @app.get("/api/sessions")
762
+ async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
763
+ async with aiosqlite.connect(db_path) as db:
764
+ db.row_factory = aiosqlite.Row
765
+
766
+ # NEW: If importing/exporting all, remove limit if special flag or high limit
767
+ # For simplicity: if limit is -1, return all
768
+ limit_clause = "LIMIT ? OFFSET ?"
769
+ params = []
770
+
771
+ base_query = "SELECT * FROM focus_sessions"
772
+ where_clause = ""
773
+
774
+ if filter == "today":
775
+ date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
776
+ where_clause = " WHERE start_time >= ?"
777
+ params.append(date_filter.isoformat())
778
+ elif filter == "week":
779
+ date_filter = datetime.now() - timedelta(days=7)
780
+ where_clause = " WHERE start_time >= ?"
781
+ params.append(date_filter.isoformat())
782
+ elif filter == "month":
783
+ date_filter = datetime.now() - timedelta(days=30)
784
+ where_clause = " WHERE start_time >= ?"
785
+ params.append(date_filter.isoformat())
786
+ elif filter == "all":
787
+ # Just ensure we only get completed sessions or all sessions
788
+ where_clause = " WHERE end_time IS NOT NULL"
789
+
790
+ query = f"{base_query}{where_clause} ORDER BY start_time DESC"
791
+
792
+ # Handle Limit for Exports
793
+ if limit == -1:
794
+ # No limit clause for export
795
+ pass
796
+ else:
797
+ query += f" {limit_clause}"
798
+ params.extend([limit, offset])
799
+
800
+ cursor = await db.execute(query, tuple(params))
801
+ rows = await cursor.fetchall()
802
+ return [dict(row) for row in rows]
803
+
804
+ # --- NEW: Import Endpoint ---
805
+ @app.post("/api/import")
806
+ async def import_sessions(sessions: List[dict]):
807
+ count = 0
808
+ try:
809
+ async with aiosqlite.connect(db_path) as db:
810
+ for session in sessions:
811
+ # Use .get() to handle potential missing fields from older versions or edits
812
+ await db.execute("""
813
+ INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
814
+ VALUES (?, ?, ?, ?, ?, ?, ?)
815
+ """, (
816
+ session.get('start_time'),
817
+ session.get('end_time'),
818
+ session.get('duration_seconds', 0),
819
+ session.get('focus_score', 0.0),
820
+ session.get('total_frames', 0),
821
+ session.get('focused_frames', 0),
822
+ session.get('created_at', session.get('start_time'))
823
+ ))
824
+ count += 1
825
+ await db.commit()
826
+ return {"status": "success", "count": count}
827
+ except Exception as e:
828
+ print(f"Import Error: {e}")
829
+ return {"status": "error", "message": str(e)}
830
+
831
+ # --- NEW: Clear History Endpoint ---
832
+ @app.delete("/api/history")
833
+ async def clear_history():
834
+ try:
835
+ async with aiosqlite.connect(db_path) as db:
836
+ # Delete events first (foreign key good practice)
837
+ await db.execute("DELETE FROM focus_events")
838
+ await db.execute("DELETE FROM focus_sessions")
839
+ await db.commit()
840
+ return {"status": "success", "message": "History cleared"}
841
+ except Exception as e:
842
+ return {"status": "error", "message": str(e)}
843
+
844
+ @app.get("/api/sessions/{session_id}")
845
+ async def get_session(session_id: int):
846
+ async with aiosqlite.connect(db_path) as db:
847
+ db.row_factory = aiosqlite.Row
848
+ cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
849
+ row = await cursor.fetchone()
850
+ if not row: raise HTTPException(status_code=404, detail="Session not found")
851
+ session = dict(row)
852
+ cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
853
+ events = [dict(r) for r in await cursor.fetchall()]
854
+ session['events'] = events
855
+ return session
856
+
857
+ @app.get("/api/settings")
858
+ async def get_settings():
859
+ async with aiosqlite.connect(db_path) as db:
860
+ db.row_factory = aiosqlite.Row
861
+ cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
862
+ row = await cursor.fetchone()
863
+ if row: return dict(row)
864
+ else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
865
+
866
+ @app.put("/api/settings")
867
+ async def update_settings(settings: SettingsUpdate):
868
+ async with aiosqlite.connect(db_path) as db:
869
+ cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
870
+ exists = await cursor.fetchone()
871
+ if not exists:
872
+ await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
873
+ await db.commit()
874
+
875
+ updates = []
876
+ params = []
877
+ if settings.sensitivity is not None:
878
+ updates.append("sensitivity = ?")
879
+ params.append(max(1, min(10, settings.sensitivity)))
880
+ if settings.notification_enabled is not None:
881
+ updates.append("notification_enabled = ?")
882
+ params.append(settings.notification_enabled)
883
+ if settings.notification_threshold is not None:
884
+ updates.append("notification_threshold = ?")
885
+ params.append(max(5, min(300, settings.notification_threshold)))
886
+ if settings.frame_rate is not None:
887
+ updates.append("frame_rate = ?")
888
+ params.append(max(5, min(60, settings.frame_rate)))
889
+ if settings.model_name is not None and settings.model_name in pipelines and pipelines[settings.model_name] is not None:
890
+ updates.append("model_name = ?")
891
+ params.append(settings.model_name)
892
+ global _cached_model_name
893
+ _cached_model_name = settings.model_name
894
+
895
+ if updates:
896
+ query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
897
+ await db.execute(query, params)
898
+ await db.commit()
899
+ return {"status": "success", "updated": len(updates) > 0}
900
+
901
+ @app.get("/api/stats/summary")
902
+ async def get_stats_summary():
903
+ async with aiosqlite.connect(db_path) as db:
904
+ cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
905
+ total_sessions = (await cursor.fetchone())[0]
906
+ cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
907
+ total_focus_time = (await cursor.fetchone())[0] or 0
908
+ cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
909
+ avg_focus_score = (await cursor.fetchone())[0] or 0.0
910
+ cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
911
+ dates = [row[0] for row in await cursor.fetchall()]
912
+
913
+ streak_days = 0
914
+ if dates:
915
+ current_date = datetime.now().date()
916
+ for i, date_str in enumerate(dates):
917
+ session_date = datetime.fromisoformat(date_str).date()
918
+ expected_date = current_date - timedelta(days=i)
919
+ if session_date == expected_date: streak_days += 1
920
+ else: break
921
+ return {
922
+ 'total_sessions': total_sessions,
923
+ 'total_focus_time': int(total_focus_time),
924
+ 'avg_focus_score': round(avg_focus_score, 3),
925
+ 'streak_days': streak_days
926
+ }
927
+
928
+ @app.get("/api/models")
929
+ async def get_available_models():
930
+ """Return list of loaded model names and which is currently active."""
931
+ available = [name for name, p in pipelines.items() if p is not None]
932
+ async with aiosqlite.connect(db_path) as db:
933
+ cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
934
+ row = await cursor.fetchone()
935
+ current = row[0] if row else "mlp"
936
+ if current not in available and available:
937
+ current = available[0]
938
+ return {"available": available, "current": current}
939
+
940
+ @app.get("/api/mesh-topology")
941
+ async def get_mesh_topology():
942
+ """Return tessellation edge pairs for client-side face mesh drawing (cached by client)."""
943
+ return {"tessellation": _TESSELATION_CONNS}
944
+
945
+ @app.get("/health")
946
+ async def health_check():
947
+ available = [name for name, p in pipelines.items() if p is not None]
948
+ return {"status": "healthy", "models_loaded": available, "database": os.path.exists(db_path)}
949
+
950
+ # ================ STATIC FILES (SPA SUPPORT) ================
951
+
952
+ # Resolve frontend dir from this file so it works regardless of cwd.
953
+ # Prefer a built `dist/` app when present, otherwise fall back to `static/`.
954
+ _BASE_DIR = Path(__file__).resolve().parent
955
+ _DIST_DIR = _BASE_DIR / "dist"
956
+ _STATIC_DIR = _BASE_DIR / "static"
957
+ _FRONTEND_DIR = _DIST_DIR if (_DIST_DIR / "index.html").is_file() else _STATIC_DIR
958
+ _ASSETS_DIR = _FRONTEND_DIR / "assets"
959
+
960
+ # 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
961
+ if _ASSETS_DIR.is_dir():
962
+ app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")
963
+
964
+ # 2. Catch-all for SPA: serve index.html for app routes, never for /assets (would break JS MIME type)
965
+ @app.get("/{full_path:path}")
966
+ async def serve_react_app(full_path: str, request: Request):
967
+ if full_path.startswith("api") or full_path.startswith("ws"):
968
+ raise HTTPException(status_code=404, detail="Not Found")
969
+ # Don't serve HTML for asset paths; let them 404 so we don't break module script loading
970
+ if full_path.startswith("assets") or full_path.startswith("assets/"):
971
+ raise HTTPException(status_code=404, detail="Not Found")
972
+
973
+ file_path = _FRONTEND_DIR / full_path
974
+ if full_path and file_path.is_file():
975
+ return FileResponse(str(file_path))
976
+
977
+ index_path = _FRONTEND_DIR / "index.html"
978
+ if index_path.is_file():
979
+ return FileResponse(str(index_path))
980
+ return {"message": "React app not found. Please run 'npm run build' and copy dist to static if needed."}
main_yolo.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
2
+ from fastapi.staticfiles import StaticFiles
3
+ from fastapi.responses import FileResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ from typing import Optional, List, Any
7
+ import base64
8
+ import cv2
9
+ import numpy as np
10
+ import aiosqlite
11
+ import json
12
+ from datetime import datetime, timedelta
13
+ import math
14
+ import os
15
+ from pathlib import Path
16
+ from typing import Callable
17
+ import asyncio
18
+
19
+ from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
20
+ from av import VideoFrame
21
+
22
+ from ui.pipeline import MLPPipeline
23
+
24
+ # Initialize FastAPI app
25
+ app = FastAPI(title="Focus Guard API")
26
+
27
+ # Add CORS middleware
28
+ app.add_middleware(
29
+ CORSMiddleware,
30
+ allow_origins=["*"],
31
+ allow_credentials=True,
32
+ allow_methods=["*"],
33
+ allow_headers=["*"],
34
+ )
35
+
36
+ # Global variables
37
+ db_path = "focus_guard.db"
38
+ pcs = set()
39
+
40
+ async def _wait_for_ice_gathering(pc: RTCPeerConnection):
41
+ if pc.iceGatheringState == "complete":
42
+ return
43
+ done = asyncio.Event()
44
+
45
+ @pc.on("icegatheringstatechange")
46
+ def _on_state_change():
47
+ if pc.iceGatheringState == "complete":
48
+ done.set()
49
+
50
+ await done.wait()
51
+
52
+ # ================ DATABASE MODELS ================
53
+
54
+ async def init_database():
55
+ """Initialize SQLite database with required tables"""
56
+ async with aiosqlite.connect(db_path) as db:
57
+ # FocusSessions table
58
+ await db.execute("""
59
+ CREATE TABLE IF NOT EXISTS focus_sessions (
60
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
61
+ start_time TIMESTAMP NOT NULL,
62
+ end_time TIMESTAMP,
63
+ duration_seconds INTEGER DEFAULT 0,
64
+ focus_score REAL DEFAULT 0.0,
65
+ total_frames INTEGER DEFAULT 0,
66
+ focused_frames INTEGER DEFAULT 0,
67
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
68
+ )
69
+ """)
70
+
71
+ # FocusEvents table
72
+ await db.execute("""
73
+ CREATE TABLE IF NOT EXISTS focus_events (
74
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
75
+ session_id INTEGER NOT NULL,
76
+ timestamp TIMESTAMP NOT NULL,
77
+ is_focused BOOLEAN NOT NULL,
78
+ confidence REAL NOT NULL,
79
+ detection_data TEXT,
80
+ FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
81
+ )
82
+ """)
83
+
84
+ # UserSettings table
85
+ await db.execute("""
86
+ CREATE TABLE IF NOT EXISTS user_settings (
87
+ id INTEGER PRIMARY KEY CHECK (id = 1),
88
+ sensitivity INTEGER DEFAULT 6,
89
+ notification_enabled BOOLEAN DEFAULT 1,
90
+ notification_threshold INTEGER DEFAULT 30,
91
+ frame_rate INTEGER DEFAULT 30,
92
+ model_name TEXT DEFAULT 'yolov8n.pt'
93
+ )
94
+ """)
95
+
96
+ # Insert default settings if not exists
97
+ await db.execute("""
98
+ INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
99
+ VALUES (1, 6, 1, 30, 30, 'yolov8n.pt')
100
+ """)
101
+
102
+ await db.commit()
103
+
104
+ # ================ PYDANTIC MODELS ================
105
+
106
+ class SessionCreate(BaseModel):
107
+ pass
108
+
109
+ class SessionEnd(BaseModel):
110
+ session_id: int
111
+
112
+ class SettingsUpdate(BaseModel):
113
+ sensitivity: Optional[int] = None
114
+ notification_enabled: Optional[bool] = None
115
+ notification_threshold: Optional[int] = None
116
+ frame_rate: Optional[int] = None
117
+
118
+ class VideoTransformTrack(VideoStreamTrack):
119
+ def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
120
+ super().__init__()
121
+ self.track = track
122
+ self.session_id = session_id
123
+ self.get_channel = get_channel
124
+ self.last_inference_time = 0
125
+ self.min_inference_interval = 1 / 60
126
+ self.last_frame = None
127
+
128
+ async def recv(self):
129
+ frame = await self.track.recv()
130
+ img = frame.to_ndarray(format="bgr24")
131
+ if img is None:
132
+ return frame
133
+
134
+ # Normalize size for inference/drawing
135
+ img = cv2.resize(img, (640, 480))
136
+
137
+ now = datetime.now().timestamp()
138
+ do_infer = (now - self.last_inference_time) >= self.min_inference_interval
139
+
140
+ if do_infer and mlp_pipeline is not None:
141
+ self.last_inference_time = now
142
+ out = mlp_pipeline.process_frame(img)
143
+ is_focused = out["is_focused"]
144
+ confidence = out["mlp_prob"]
145
+ metadata = {"s_face": out["s_face"], "s_eye": out["s_eye"], "mar": out["mar"]}
146
+ detections = []
147
+ status_text = "FOCUSED" if is_focused else "NOT FOCUSED"
148
+ color = (0, 255, 0) if is_focused else (0, 0, 255)
149
+ cv2.putText(img, status_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
150
+ cv2.putText(img, f"Confidence: {confidence * 100:.1f}%", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
151
+
152
+ if self.session_id:
153
+ await store_focus_event(self.session_id, is_focused, confidence, metadata)
154
+
155
+ channel = self.get_channel()
156
+ if channel and channel.readyState == "open":
157
+ try:
158
+ channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": detections}))
159
+ except Exception:
160
+ pass
161
+
162
+ self.last_frame = img
163
+ elif self.last_frame is not None:
164
+ img = self.last_frame
165
+
166
+ new_frame = VideoFrame.from_ndarray(img, format="bgr24")
167
+ new_frame.pts = frame.pts
168
+ new_frame.time_base = frame.time_base
169
+ return new_frame
170
+
171
+ # ================ DATABASE OPERATIONS ================
172
+
173
+ async def create_session():
174
+ async with aiosqlite.connect(db_path) as db:
175
+ cursor = await db.execute(
176
+ "INSERT INTO focus_sessions (start_time) VALUES (?)",
177
+ (datetime.now().isoformat(),)
178
+ )
179
+ await db.commit()
180
+ return cursor.lastrowid
181
+
182
+ async def end_session(session_id: int):
183
+ async with aiosqlite.connect(db_path) as db:
184
+ cursor = await db.execute(
185
+ "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
186
+ (session_id,)
187
+ )
188
+ row = await cursor.fetchone()
189
+
190
+ if not row:
191
+ return None
192
+
193
+ start_time_str, total_frames, focused_frames = row
194
+ start_time = datetime.fromisoformat(start_time_str)
195
+ end_time = datetime.now()
196
+ duration = (end_time - start_time).total_seconds()
197
+ focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
198
+
199
+ await db.execute("""
200
+ UPDATE focus_sessions
201
+ SET end_time = ?, duration_seconds = ?, focus_score = ?
202
+ WHERE id = ?
203
+ """, (end_time.isoformat(), int(duration), focus_score, session_id))
204
+
205
+ await db.commit()
206
+
207
+ return {
208
+ 'session_id': session_id,
209
+ 'start_time': start_time_str,
210
+ 'end_time': end_time.isoformat(),
211
+ 'duration_seconds': int(duration),
212
+ 'focus_score': round(focus_score, 3),
213
+ 'total_frames': total_frames,
214
+ 'focused_frames': focused_frames
215
+ }
216
+
217
+ async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
218
+ async with aiosqlite.connect(db_path) as db:
219
+ await db.execute("""
220
+ INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
221
+ VALUES (?, ?, ?, ?, ?)
222
+ """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
223
+
224
+ await db.execute(f"""
225
+ UPDATE focus_sessions
226
+ SET total_frames = total_frames + 1,
227
+ focused_frames = focused_frames + {1 if is_focused else 0}
228
+ WHERE id = ?
229
+ """, (session_id,))
230
+ await db.commit()
231
+
232
+ # ================ STARTUP/SHUTDOWN ================
233
+
234
+ mlp_pipeline = None
235
+
236
+ @app.on_event("startup")
237
+ async def startup_event():
238
+ global mlp_pipeline
239
+ print(" Starting Focus Guard API...")
240
+ await init_database()
241
+ print("[OK] Database initialized")
242
+
243
+ mlp_pipeline = MLPPipeline()
244
+ print("[OK] MLPPipeline loaded")
245
+
246
+ @app.on_event("shutdown")
247
+ async def shutdown_event():
248
+ print(" Shutting down Focus Guard API...")
249
+
250
+ # ================ WEBRTC SIGNALING ================
251
+
252
+ @app.post("/api/webrtc/offer")
253
+ async def webrtc_offer(offer: dict):
254
+ try:
255
+ print(f"Received WebRTC offer")
256
+
257
+ pc = RTCPeerConnection()
258
+ pcs.add(pc)
259
+
260
+ session_id = await create_session()
261
+ print(f"Created session: {session_id}")
262
+
263
+ channel_ref = {"channel": None}
264
+
265
+ @pc.on("datachannel")
266
+ def on_datachannel(channel):
267
+ print(f"Data channel opened")
268
+ channel_ref["channel"] = channel
269
+
270
+ @pc.on("track")
271
+ def on_track(track):
272
+ print(f"Received track: {track.kind}")
273
+ if track.kind == "video":
274
+ local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
275
+ pc.addTrack(local_track)
276
+ print(f"Video track added")
277
+
278
+ @track.on("ended")
279
+ async def on_ended():
280
+ print(f"Track ended")
281
+
282
+ @pc.on("connectionstatechange")
283
+ async def on_connectionstatechange():
284
+ print(f"Connection state changed: {pc.connectionState}")
285
+ if pc.connectionState in ("failed", "closed", "disconnected"):
286
+ try:
287
+ await end_session(session_id)
288
+ except Exception as e:
289
+ print(f"⚠Error ending session: {e}")
290
+ pcs.discard(pc)
291
+ await pc.close()
292
+
293
+ await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
294
+ print(f"Remote description set")
295
+
296
+ answer = await pc.createAnswer()
297
+ await pc.setLocalDescription(answer)
298
+ print(f"Answer created")
299
+
300
+ await _wait_for_ice_gathering(pc)
301
+ print(f"ICE gathering complete")
302
+
303
+ return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
304
+
305
+ except Exception as e:
306
+ print(f"WebRTC offer error: {e}")
307
+ import traceback
308
+ traceback.print_exc()
309
+ raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
310
+
311
+ # ================ WEBSOCKET ================
312
+
313
+ @app.websocket("/ws/video")
314
+ async def websocket_endpoint(websocket: WebSocket):
315
+ await websocket.accept()
316
+ session_id = None
317
+ frame_count = 0
318
+ last_inference_time = 0
319
+ min_inference_interval = 1 / 60
320
+
321
+ try:
322
+ async with aiosqlite.connect(db_path) as db:
323
+ cursor = await db.execute("SELECT sensitivity FROM user_settings WHERE id = 1")
324
+ row = await cursor.fetchone()
325
+ sensitivity = row[0] if row else 6
326
+
327
+ while True:
328
+ data = await websocket.receive_json()
329
+
330
+ if data['type'] == 'frame':
331
+ from time import time
332
+ current_time = time()
333
+ if current_time - last_inference_time < min_inference_interval:
334
+ await websocket.send_json({'type': 'ack', 'frame_count': frame_count})
335
+ continue
336
+ last_inference_time = current_time
337
+
338
+ try:
339
+ img_data = base64.b64decode(data['image'])
340
+ nparr = np.frombuffer(img_data, np.uint8)
341
+ frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
342
+
343
+ if frame is None: continue
344
+ frame = cv2.resize(frame, (640, 480))
345
+
346
+ if mlp_pipeline is not None:
347
+ out = mlp_pipeline.process_frame(frame)
348
+
349
+ is_focused = out["is_focused"]
350
+ confidence = out["mlp_prob"]
351
+ metadata = {
352
+ "s_face": out["s_face"],
353
+ "s_eye": out["s_eye"],
354
+ "mar": out["mar"]
355
+ }
356
+ else:
357
+ is_focused = False
358
+ confidence = 0.0
359
+ metadata = {}
360
+
361
+ detections = []
362
+
363
+ if session_id:
364
+ await store_focus_event(session_id, is_focused, confidence, metadata)
365
+
366
+ await websocket.send_json({
367
+ 'type': 'detection',
368
+ 'focused': is_focused,
369
+ 'confidence': round(confidence, 3),
370
+ 'detections': detections,
371
+ 'frame_count': frame_count
372
+ })
373
+ frame_count += 1
374
+ except Exception as e:
375
+ print(f"Error processing frame: {e}")
376
+ await websocket.send_json({'type': 'error', 'message': str(e)})
377
+
378
+ elif data['type'] == 'start_session':
379
+ session_id = await create_session()
380
+ await websocket.send_json({'type': 'session_started', 'session_id': session_id})
381
+
382
+ elif data['type'] == 'end_session':
383
+ if session_id:
384
+ print(f"Ending session {session_id}...")
385
+ summary = await end_session(session_id)
386
+ print(f"Session summary: {summary}")
387
+ if summary:
388
+ await websocket.send_json({'type': 'session_ended', 'summary': summary})
389
+ print("Session ended message sent")
390
+ else:
391
+ print("Warning: No summary returned")
392
+ session_id = None
393
+ else:
394
+ print("Warning: end_session called but no active session_id")
395
+
396
+ except WebSocketDisconnect:
397
+ if session_id: await end_session(session_id)
398
+ except Exception as e:
399
+ if websocket.client_state.value == 1: await websocket.close()
400
+
401
+ # ================ API ENDPOINTS ================
402
+
403
+ @app.post("/api/sessions/start")
404
+ async def api_start_session():
405
+ session_id = await create_session()
406
+ return {"session_id": session_id}
407
+
408
+ @app.post("/api/sessions/end")
409
+ async def api_end_session(data: SessionEnd):
410
+ summary = await end_session(data.session_id)
411
+ if not summary: raise HTTPException(status_code=404, detail="Session not found")
412
+ return summary
413
+
414
+ @app.get("/api/sessions")
415
+ async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
416
+ async with aiosqlite.connect(db_path) as db:
417
+ db.row_factory = aiosqlite.Row
418
+
419
+ # NEW: If importing/exporting all, remove limit if special flag or high limit
420
+ # For simplicity: if limit is -1, return all
421
+ limit_clause = "LIMIT ? OFFSET ?"
422
+ params = []
423
+
424
+ base_query = "SELECT * FROM focus_sessions"
425
+ where_clause = ""
426
+
427
+ if filter == "today":
428
+ date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
429
+ where_clause = " WHERE start_time >= ?"
430
+ params.append(date_filter.isoformat())
431
+ elif filter == "week":
432
+ date_filter = datetime.now() - timedelta(days=7)
433
+ where_clause = " WHERE start_time >= ?"
434
+ params.append(date_filter.isoformat())
435
+ elif filter == "month":
436
+ date_filter = datetime.now() - timedelta(days=30)
437
+ where_clause = " WHERE start_time >= ?"
438
+ params.append(date_filter.isoformat())
439
+ elif filter == "all":
440
+ # Just ensure we only get completed sessions or all sessions
441
+ where_clause = " WHERE end_time IS NOT NULL"
442
+
443
+ query = f"{base_query}{where_clause} ORDER BY start_time DESC"
444
+
445
+ # Handle Limit for Exports
446
+ if limit == -1:
447
+ # No limit clause for export
448
+ pass
449
+ else:
450
+ query += f" {limit_clause}"
451
+ params.extend([limit, offset])
452
+
453
+ cursor = await db.execute(query, tuple(params))
454
+ rows = await cursor.fetchall()
455
+ return [dict(row) for row in rows]
456
+
457
+ # --- NEW: Import Endpoint ---
458
+ @app.post("/api/import")
459
+ async def import_sessions(sessions: List[dict]):
460
+ count = 0
461
+ try:
462
+ async with aiosqlite.connect(db_path) as db:
463
+ for session in sessions:
464
+ # Use .get() to handle potential missing fields from older versions or edits
465
+ await db.execute("""
466
+ INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
467
+ VALUES (?, ?, ?, ?, ?, ?, ?)
468
+ """, (
469
+ session.get('start_time'),
470
+ session.get('end_time'),
471
+ session.get('duration_seconds', 0),
472
+ session.get('focus_score', 0.0),
473
+ session.get('total_frames', 0),
474
+ session.get('focused_frames', 0),
475
+ session.get('created_at', session.get('start_time'))
476
+ ))
477
+ count += 1
478
+ await db.commit()
479
+ return {"status": "success", "count": count}
480
+ except Exception as e:
481
+ print(f"Import Error: {e}")
482
+ return {"status": "error", "message": str(e)}
483
+
484
+ # --- NEW: Clear History Endpoint ---
485
+ @app.delete("/api/history")
486
+ async def clear_history():
487
+ try:
488
+ async with aiosqlite.connect(db_path) as db:
489
+ # Delete events first (foreign key good practice)
490
+ await db.execute("DELETE FROM focus_events")
491
+ await db.execute("DELETE FROM focus_sessions")
492
+ await db.commit()
493
+ return {"status": "success", "message": "History cleared"}
494
+ except Exception as e:
495
+ return {"status": "error", "message": str(e)}
496
+
497
+ @app.get("/api/sessions/{session_id}")
498
+ async def get_session(session_id: int):
499
+ async with aiosqlite.connect(db_path) as db:
500
+ db.row_factory = aiosqlite.Row
501
+ cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
502
+ row = await cursor.fetchone()
503
+ if not row: raise HTTPException(status_code=404, detail="Session not found")
504
+ session = dict(row)
505
+ cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
506
+ events = [dict(r) for r in await cursor.fetchall()]
507
+ session['events'] = events
508
+ return session
509
+
510
+ @app.get("/api/settings")
511
+ async def get_settings():
512
+ async with aiosqlite.connect(db_path) as db:
513
+ db.row_factory = aiosqlite.Row
514
+ cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
515
+ row = await cursor.fetchone()
516
+ if row: return dict(row)
517
+ else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'yolov8n.pt'}
518
+
519
+ @app.put("/api/settings")
520
+ async def update_settings(settings: SettingsUpdate):
521
+ async with aiosqlite.connect(db_path) as db:
522
+ cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
523
+ exists = await cursor.fetchone()
524
+ if not exists:
525
+ await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
526
+ await db.commit()
527
+
528
+ updates = []
529
+ params = []
530
+ if settings.sensitivity is not None:
531
+ updates.append("sensitivity = ?")
532
+ params.append(max(1, min(10, settings.sensitivity)))
533
+ if settings.notification_enabled is not None:
534
+ updates.append("notification_enabled = ?")
535
+ params.append(settings.notification_enabled)
536
+ if settings.notification_threshold is not None:
537
+ updates.append("notification_threshold = ?")
538
+ params.append(max(5, min(300, settings.notification_threshold)))
539
+ if settings.frame_rate is not None:
540
+ updates.append("frame_rate = ?")
541
+ params.append(max(5, min(60, settings.frame_rate)))
542
+
543
+ if updates:
544
+ query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
545
+ await db.execute(query, params)
546
+ await db.commit()
547
+ return {"status": "success", "updated": len(updates) > 0}
548
+
549
+ @app.get("/api/stats/summary")
550
+ async def get_stats_summary():
551
+ async with aiosqlite.connect(db_path) as db:
552
+ cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
553
+ total_sessions = (await cursor.fetchone())[0]
554
+ cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
555
+ total_focus_time = (await cursor.fetchone())[0] or 0
556
+ cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
557
+ avg_focus_score = (await cursor.fetchone())[0] or 0.0
558
+ cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
559
+ dates = [row[0] for row in await cursor.fetchall()]
560
+
561
+ streak_days = 0
562
+ if dates:
563
+ current_date = datetime.now().date()
564
+ for i, date_str in enumerate(dates):
565
+ session_date = datetime.fromisoformat(date_str).date()
566
+ expected_date = current_date - timedelta(days=i)
567
+ if session_date == expected_date: streak_days += 1
568
+ else: break
569
+ return {
570
+ 'total_sessions': total_sessions,
571
+ 'total_focus_time': int(total_focus_time),
572
+ 'avg_focus_score': round(avg_focus_score, 3),
573
+ 'streak_days': streak_days
574
+ }
575
+
576
+ @app.get("/health")
577
+ async def health_check():
578
+ return {"status": "healthy", "model_loaded": mlp_pipeline is not None, "database": os.path.exists(db_path)}
579
+
580
+ # ================ STATIC FILES (SPA SUPPORT) ================
581
+
582
+ FRONTEND_DIR = "dist" if os.path.exists("dist/index.html") else "static"
583
+
584
+ assets_path = os.path.join(FRONTEND_DIR, "assets")
585
+ if os.path.exists(assets_path):
586
+ app.mount("/assets", StaticFiles(directory=assets_path), name="assets")
587
+
588
+ @app.get("/{full_path:path}")
589
+ async def serve_react_app(full_path: str, request: Request):
590
+ if full_path.startswith("api") or full_path.startswith("ws"):
591
+ raise HTTPException(status_code=404, detail="Not Found")
592
+
593
+ file_path = os.path.join(FRONTEND_DIR, full_path)
594
+ if os.path.isfile(file_path):
595
+ return FileResponse(file_path)
596
+
597
+ index_path = os.path.join(FRONTEND_DIR, "index.html")
598
+ if os.path.exists(index_path):
599
+ return FileResponse(index_path)
600
+ else:
601
+ return {"message": "React app not found. Please run npm run build."}