Yingtao-Zheng commited on
Commit
9a472c2
·
1 Parent(s): 6941e5d

Combine all updated main code together, forming new main.py

Browse files
Files changed (3) hide show
  1. main.py +22 -6
  2. main_combined.py +0 -980
  3. main_yolo.py +0 -601
main.py CHANGED
@@ -306,7 +306,13 @@ class VideoTransformTrack(VideoStreamTrack):
306
  channel = self.get_channel()
307
  if channel and channel.readyState == "open":
308
  try:
309
- channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": detections}))
 
 
 
 
 
 
310
  except Exception:
311
  pass
312
 
@@ -707,8 +713,10 @@ async def websocket_endpoint(websocket: WebSocket):
707
  "type": "detection",
708
  "focused": is_focused,
709
  "confidence": round(confidence, 3),
 
710
  "model": model_name,
711
  "fc": frame_count,
 
712
  }
713
  if active_pipeline is not None:
714
  # Send detailed metrics for HUD
@@ -941,9 +949,13 @@ async def health_check():
941
 
942
  # ================ STATIC FILES (SPA SUPPORT) ================
943
 
944
- # Resolve static dir from this file so it works regardless of cwd
945
- _STATIC_DIR = Path(__file__).resolve().parent / "static"
946
- _ASSETS_DIR = _STATIC_DIR / "assets"
 
 
 
 
947
 
948
  # 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
949
  if _ASSETS_DIR.is_dir():
@@ -958,7 +970,11 @@ async def serve_react_app(full_path: str, request: Request):
958
  if full_path.startswith("assets") or full_path.startswith("assets/"):
959
  raise HTTPException(status_code=404, detail="Not Found")
960
 
961
- index_path = _STATIC_DIR / "index.html"
 
 
 
 
962
  if index_path.is_file():
963
  return FileResponse(str(index_path))
964
- return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}
 
306
  channel = self.get_channel()
307
  if channel and channel.readyState == "open":
308
  try:
309
+ channel.send(json.dumps({
310
+ "type": "detection",
311
+ "focused": is_focused,
312
+ "confidence": round(confidence, 3),
313
+ "detections": [],
314
+ "model": model_name,
315
+ }))
316
  except Exception:
317
  pass
318
 
 
713
  "type": "detection",
714
  "focused": is_focused,
715
  "confidence": round(confidence, 3),
716
+ "detections": [],
717
  "model": model_name,
718
  "fc": frame_count,
719
+ "frame_count": frame_count,
720
  }
721
  if active_pipeline is not None:
722
  # Send detailed metrics for HUD
 
949
 
950
  # ================ STATIC FILES (SPA SUPPORT) ================
951
 
952
+ # Resolve frontend dir from this file so it works regardless of cwd.
953
+ # Prefer a built `dist/` app when present, otherwise fall back to `static/`.
954
+ _BASE_DIR = Path(__file__).resolve().parent
955
+ _DIST_DIR = _BASE_DIR / "dist"
956
+ _STATIC_DIR = _BASE_DIR / "static"
957
+ _FRONTEND_DIR = _DIST_DIR if (_DIST_DIR / "index.html").is_file() else _STATIC_DIR
958
+ _ASSETS_DIR = _FRONTEND_DIR / "assets"
959
 
960
  # 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
961
  if _ASSETS_DIR.is_dir():
 
970
  if full_path.startswith("assets") or full_path.startswith("assets/"):
971
  raise HTTPException(status_code=404, detail="Not Found")
972
 
973
+ file_path = _FRONTEND_DIR / full_path
974
+ if full_path and file_path.is_file():
975
+ return FileResponse(str(file_path))
976
+
977
+ index_path = _FRONTEND_DIR / "index.html"
978
  if index_path.is_file():
979
  return FileResponse(str(index_path))
980
+ return {"message": "React app not found. Please run 'npm run build' and copy dist to static if needed."}
main_combined.py DELETED
@@ -1,980 +0,0 @@
1
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
2
- from fastapi.staticfiles import StaticFiles
3
- from fastapi.responses import FileResponse
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from pydantic import BaseModel
6
- from typing import Optional, List, Any
7
- import base64
8
- import cv2
9
- import numpy as np
10
- import aiosqlite
11
- import json
12
- from datetime import datetime, timedelta
13
- import math
14
- import os
15
- from pathlib import Path
16
- from typing import Callable
17
- import asyncio
18
- import concurrent.futures
19
- import threading
20
-
21
- from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
22
- from av import VideoFrame
23
-
24
- from mediapipe.tasks.python.vision import FaceLandmarksConnections
25
- from ui.pipeline import FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline
26
- from models.face_mesh import FaceMeshDetector
27
-
28
- # ================ FACE MESH DRAWING (server-side, for WebRTC) ================
29
-
30
- _FONT = cv2.FONT_HERSHEY_SIMPLEX
31
- _CYAN = (255, 255, 0)
32
- _GREEN = (0, 255, 0)
33
- _MAGENTA = (255, 0, 255)
34
- _ORANGE = (0, 165, 255)
35
- _RED = (0, 0, 255)
36
- _WHITE = (255, 255, 255)
37
- _LIGHT_GREEN = (144, 238, 144)
38
-
39
- _TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
40
- _CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
41
- _LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
42
- _RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
43
- _NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
44
- _LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
45
- _LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
46
- _LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
47
- _RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
48
-
49
-
50
- def _lm_px(lm, idx, w, h):
51
- return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
52
-
53
-
54
- def _draw_polyline(frame, lm, indices, w, h, color, thickness):
55
- for i in range(len(indices) - 1):
56
- cv2.line(frame, _lm_px(lm, indices[i], w, h), _lm_px(lm, indices[i + 1], w, h), color, thickness, cv2.LINE_AA)
57
-
58
-
59
- def _draw_face_mesh(frame, lm, w, h):
60
- """Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines."""
61
- # Tessellation (gray triangular grid, semi-transparent)
62
- overlay = frame.copy()
63
- for s, e in _TESSELATION_CONNS:
64
- cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
65
- cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
66
- # Contours
67
- for s, e in _CONTOUR_CONNS:
68
- cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
69
- # Eyebrows
70
- _draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
71
- _draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
72
- # Nose
73
- _draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
74
- # Lips
75
- _draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
76
- _draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
77
- # Eyes
78
- left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
79
- cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
80
- right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
81
- cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
82
- # EAR key points
83
- for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
84
- for idx in indices:
85
- cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
86
- # Irises + gaze lines
87
- for iris_idx, eye_inner, eye_outer in [
88
- (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
89
- (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
90
- ]:
91
- iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
92
- center = iris_pts[0]
93
- if len(iris_pts) >= 5:
94
- radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
95
- radius = max(int(np.mean(radii)), 2)
96
- cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
97
- cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
98
- eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
99
- eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
100
- dx, dy = center[0] - eye_cx, center[1] - eye_cy
101
- cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
102
-
103
-
104
- def _draw_hud(frame, result, model_name):
105
- """Draw status bar and detail overlay matching live_demo.py."""
106
- h, w = frame.shape[:2]
107
- is_focused = result["is_focused"]
108
- status = "FOCUSED" if is_focused else "NOT FOCUSED"
109
- color = _GREEN if is_focused else _RED
110
-
111
- # Top bar
112
- cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
113
- cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
114
- cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)
115
-
116
- # Detail line
117
- conf = result.get("mlp_prob", result.get("raw_score", 0.0))
118
- mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
119
- sf = result.get("s_face", 0)
120
- se = result.get("s_eye", 0)
121
- detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
122
- cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)
123
-
124
- # Head pose (top right)
125
- if result.get("yaw") is not None:
126
- cv2.putText(frame, f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
127
- (w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
128
-
129
- # Yawn indicator
130
- if result.get("is_yawning"):
131
- cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
132
-
133
- # Landmark indices used for face mesh drawing on client (union of all groups).
134
- # Sending only these instead of all 478 saves ~60% of the landmarks payload.
135
- _MESH_INDICES = sorted(set(
136
- [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] # face oval
137
- + [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246] # left eye
138
- + [362,382,381,380,374,373,390,249,263,466,388,387,386,385,384,398] # right eye
139
- + [468,469,470,471,472, 473,474,475,476,477] # irises
140
- + [70,63,105,66,107,55,65,52,53,46] # left eyebrow
141
- + [300,293,334,296,336,285,295,282,283,276] # right eyebrow
142
- + [6,197,195,5,4,1,19,94,2] # nose bridge
143
- + [61,146,91,181,84,17,314,405,321,375,291,409,270,269,267,0,37,39,40,185] # lips outer
144
- + [78,95,88,178,87,14,317,402,318,324,308,415,310,311,312,13,82,81,80,191] # lips inner
145
- + [33,160,158,133,153,145] # left EAR key points
146
- + [362,385,387,263,373,380] # right EAR key points
147
- ))
148
- # Build a lookup: original_index -> position in sparse array, so client can reconstruct.
149
- _MESH_INDEX_SET = set(_MESH_INDICES)
150
-
151
- # Initialize FastAPI app
152
- app = FastAPI(title="Focus Guard API")
153
-
154
- # Add CORS middleware
155
- app.add_middleware(
156
- CORSMiddleware,
157
- allow_origins=["*"],
158
- allow_credentials=True,
159
- allow_methods=["*"],
160
- allow_headers=["*"],
161
- )
162
-
163
- # Global variables
164
- db_path = "focus_guard.db"
165
- pcs = set()
166
- _cached_model_name = "mlp" # in-memory cache, updated via /api/settings
167
-
168
- async def _wait_for_ice_gathering(pc: RTCPeerConnection):
169
- if pc.iceGatheringState == "complete":
170
- return
171
- done = asyncio.Event()
172
-
173
- @pc.on("icegatheringstatechange")
174
- def _on_state_change():
175
- if pc.iceGatheringState == "complete":
176
- done.set()
177
-
178
- await done.wait()
179
-
180
- # ================ DATABASE MODELS ================
181
-
182
- async def init_database():
183
- """Initialize SQLite database with required tables"""
184
- async with aiosqlite.connect(db_path) as db:
185
- # FocusSessions table
186
- await db.execute("""
187
- CREATE TABLE IF NOT EXISTS focus_sessions (
188
- id INTEGER PRIMARY KEY AUTOINCREMENT,
189
- start_time TIMESTAMP NOT NULL,
190
- end_time TIMESTAMP,
191
- duration_seconds INTEGER DEFAULT 0,
192
- focus_score REAL DEFAULT 0.0,
193
- total_frames INTEGER DEFAULT 0,
194
- focused_frames INTEGER DEFAULT 0,
195
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
196
- )
197
- """)
198
-
199
- # FocusEvents table
200
- await db.execute("""
201
- CREATE TABLE IF NOT EXISTS focus_events (
202
- id INTEGER PRIMARY KEY AUTOINCREMENT,
203
- session_id INTEGER NOT NULL,
204
- timestamp TIMESTAMP NOT NULL,
205
- is_focused BOOLEAN NOT NULL,
206
- confidence REAL NOT NULL,
207
- detection_data TEXT,
208
- FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
209
- )
210
- """)
211
-
212
- # UserSettings table
213
- await db.execute("""
214
- CREATE TABLE IF NOT EXISTS user_settings (
215
- id INTEGER PRIMARY KEY CHECK (id = 1),
216
- sensitivity INTEGER DEFAULT 6,
217
- notification_enabled BOOLEAN DEFAULT 1,
218
- notification_threshold INTEGER DEFAULT 30,
219
- frame_rate INTEGER DEFAULT 30,
220
- model_name TEXT DEFAULT 'mlp'
221
- )
222
- """)
223
-
224
- # Insert default settings if not exists
225
- await db.execute("""
226
- INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
227
- VALUES (1, 6, 1, 30, 30, 'mlp')
228
- """)
229
-
230
- await db.commit()
231
-
232
- # ================ PYDANTIC MODELS ================
233
-
234
- class SessionCreate(BaseModel):
235
- pass
236
-
237
- class SessionEnd(BaseModel):
238
- session_id: int
239
-
240
- class SettingsUpdate(BaseModel):
241
- sensitivity: Optional[int] = None
242
- notification_enabled: Optional[bool] = None
243
- notification_threshold: Optional[int] = None
244
- frame_rate: Optional[int] = None
245
- model_name: Optional[str] = None
246
-
247
- class VideoTransformTrack(VideoStreamTrack):
248
- def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
249
- super().__init__()
250
- self.track = track
251
- self.session_id = session_id
252
- self.get_channel = get_channel
253
- self.last_inference_time = 0
254
- self.min_inference_interval = 1 / 60
255
- self.last_frame = None
256
-
257
- async def recv(self):
258
- frame = await self.track.recv()
259
- img = frame.to_ndarray(format="bgr24")
260
- if img is None:
261
- return frame
262
-
263
- # Normalize size for inference/drawing
264
- img = cv2.resize(img, (640, 480))
265
-
266
- now = datetime.now().timestamp()
267
- do_infer = (now - self.last_inference_time) >= self.min_inference_interval
268
-
269
- if do_infer:
270
- self.last_inference_time = now
271
-
272
- model_name = _cached_model_name
273
- if model_name not in pipelines or pipelines.get(model_name) is None:
274
- model_name = 'mlp'
275
- active_pipeline = pipelines.get(model_name)
276
-
277
- if active_pipeline is not None:
278
- loop = asyncio.get_event_loop()
279
- out = await loop.run_in_executor(
280
- _inference_executor,
281
- _process_frame_safe,
282
- active_pipeline,
283
- img,
284
- model_name,
285
- )
286
- is_focused = out["is_focused"]
287
- confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
288
- metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}
289
-
290
- # Draw face mesh + HUD on the video frame
291
- h_f, w_f = img.shape[:2]
292
- lm = out.get("landmarks")
293
- if lm is not None:
294
- _draw_face_mesh(img, lm, w_f, h_f)
295
- _draw_hud(img, out, model_name)
296
- else:
297
- is_focused = False
298
- confidence = 0.0
299
- metadata = {"model": model_name}
300
- cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
301
- cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)
302
-
303
- if self.session_id:
304
- await store_focus_event(self.session_id, is_focused, confidence, metadata)
305
-
306
- channel = self.get_channel()
307
- if channel and channel.readyState == "open":
308
- try:
309
- channel.send(json.dumps({
310
- "type": "detection",
311
- "focused": is_focused,
312
- "confidence": round(confidence, 3),
313
- "detections": [],
314
- "model": model_name,
315
- }))
316
- except Exception:
317
- pass
318
-
319
- self.last_frame = img
320
- elif self.last_frame is not None:
321
- img = self.last_frame
322
-
323
- new_frame = VideoFrame.from_ndarray(img, format="bgr24")
324
- new_frame.pts = frame.pts
325
- new_frame.time_base = frame.time_base
326
- return new_frame
327
-
328
- # ================ DATABASE OPERATIONS ================
329
-
330
- async def create_session():
331
- async with aiosqlite.connect(db_path) as db:
332
- cursor = await db.execute(
333
- "INSERT INTO focus_sessions (start_time) VALUES (?)",
334
- (datetime.now().isoformat(),)
335
- )
336
- await db.commit()
337
- return cursor.lastrowid
338
-
339
- async def end_session(session_id: int):
340
- async with aiosqlite.connect(db_path) as db:
341
- cursor = await db.execute(
342
- "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
343
- (session_id,)
344
- )
345
- row = await cursor.fetchone()
346
-
347
- if not row:
348
- return None
349
-
350
- start_time_str, total_frames, focused_frames = row
351
- start_time = datetime.fromisoformat(start_time_str)
352
- end_time = datetime.now()
353
- duration = (end_time - start_time).total_seconds()
354
- focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
355
-
356
- await db.execute("""
357
- UPDATE focus_sessions
358
- SET end_time = ?, duration_seconds = ?, focus_score = ?
359
- WHERE id = ?
360
- """, (end_time.isoformat(), int(duration), focus_score, session_id))
361
-
362
- await db.commit()
363
-
364
- return {
365
- 'session_id': session_id,
366
- 'start_time': start_time_str,
367
- 'end_time': end_time.isoformat(),
368
- 'duration_seconds': int(duration),
369
- 'focus_score': round(focus_score, 3),
370
- 'total_frames': total_frames,
371
- 'focused_frames': focused_frames
372
- }
373
-
374
- async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
375
- async with aiosqlite.connect(db_path) as db:
376
- await db.execute("""
377
- INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
378
- VALUES (?, ?, ?, ?, ?)
379
- """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
380
-
381
- await db.execute("""
382
- UPDATE focus_sessions
383
- SET total_frames = total_frames + 1,
384
- focused_frames = focused_frames + ?
385
- WHERE id = ?
386
- """, (1 if is_focused else 0, session_id))
387
- await db.commit()
388
-
389
-
390
- class _EventBuffer:
391
- """Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes."""
392
-
393
- def __init__(self, flush_interval: float = 2.0):
394
- self._buf: list = []
395
- self._lock = asyncio.Lock()
396
- self._flush_interval = flush_interval
397
- self._task: asyncio.Task | None = None
398
- self._total_frames = 0
399
- self._focused_frames = 0
400
-
401
- def start(self):
402
- if self._task is None:
403
- self._task = asyncio.create_task(self._flush_loop())
404
-
405
- async def stop(self):
406
- if self._task:
407
- self._task.cancel()
408
- try:
409
- await self._task
410
- except asyncio.CancelledError:
411
- pass
412
- self._task = None
413
- await self._flush()
414
-
415
- def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict):
416
- self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
417
- self._total_frames += 1
418
- if is_focused:
419
- self._focused_frames += 1
420
-
421
- async def _flush_loop(self):
422
- while True:
423
- await asyncio.sleep(self._flush_interval)
424
- await self._flush()
425
-
426
- async def _flush(self):
427
- async with self._lock:
428
- if not self._buf:
429
- return
430
- batch = self._buf[:]
431
- total = self._total_frames
432
- focused = self._focused_frames
433
- self._buf.clear()
434
- self._total_frames = 0
435
- self._focused_frames = 0
436
-
437
- if not batch:
438
- return
439
-
440
- session_id = batch[0][0]
441
- try:
442
- async with aiosqlite.connect(db_path) as db:
443
- await db.executemany("""
444
- INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
445
- VALUES (?, ?, ?, ?, ?)
446
- """, batch)
447
- await db.execute("""
448
- UPDATE focus_sessions
449
- SET total_frames = total_frames + ?,
450
- focused_frames = focused_frames + ?
451
- WHERE id = ?
452
- """, (total, focused, session_id))
453
- await db.commit()
454
- except Exception as e:
455
- print(f"[DB] Flush error: {e}")
456
-
457
- # ================ STARTUP/SHUTDOWN ================
458
-
459
- pipelines = {
460
- "geometric": None,
461
- "mlp": None,
462
- "hybrid": None,
463
- "xgboost": None,
464
- }
465
-
466
- # Thread pool for CPU-bound inference so the event loop stays responsive.
467
- _inference_executor = concurrent.futures.ThreadPoolExecutor(
468
- max_workers=4,
469
- thread_name_prefix="inference",
470
- )
471
- # One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
472
- # multiple frames are processed in parallel by the thread pool.
473
- _pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
474
-
475
-
476
- def _process_frame_safe(pipeline, frame, model_name: str):
477
- """Run process_frame in executor with per-pipeline lock."""
478
- with _pipeline_locks[model_name]:
479
- return pipeline.process_frame(frame)
480
-
481
- @app.on_event("startup")
482
- async def startup_event():
483
- global pipelines, _cached_model_name
484
- print(" Starting Focus Guard API...")
485
- await init_database()
486
- # Load cached model name from DB
487
- async with aiosqlite.connect(db_path) as db:
488
- cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
489
- row = await cursor.fetchone()
490
- if row:
491
- _cached_model_name = row[0]
492
- print("[OK] Database initialized")
493
-
494
- try:
495
- pipelines["geometric"] = FaceMeshPipeline()
496
- print("[OK] FaceMeshPipeline (geometric) loaded")
497
- except Exception as e:
498
- print(f"[WARN] FaceMeshPipeline unavailable: {e}")
499
-
500
- try:
501
- pipelines["mlp"] = MLPPipeline()
502
- print("[OK] MLPPipeline loaded")
503
- except Exception as e:
504
- print(f"[ERR] Failed to load MLPPipeline: {e}")
505
-
506
- try:
507
- pipelines["hybrid"] = HybridFocusPipeline()
508
- print("[OK] HybridFocusPipeline loaded")
509
- except Exception as e:
510
- print(f"[WARN] HybridFocusPipeline unavailable: {e}")
511
-
512
- try:
513
- pipelines["xgboost"] = XGBoostPipeline()
514
- print("[OK] XGBoostPipeline loaded")
515
- except Exception as e:
516
- print(f"[ERR] Failed to load XGBoostPipeline: {e}")
517
-
518
- @app.on_event("shutdown")
519
- async def shutdown_event():
520
- _inference_executor.shutdown(wait=False)
521
- print(" Shutting down Focus Guard API...")
522
-
523
- # ================ WEBRTC SIGNALING ================
524
-
525
- @app.post("/api/webrtc/offer")
526
- async def webrtc_offer(offer: dict):
527
- try:
528
- print(f"Received WebRTC offer")
529
-
530
- pc = RTCPeerConnection()
531
- pcs.add(pc)
532
-
533
- session_id = await create_session()
534
- print(f"Created session: {session_id}")
535
-
536
- channel_ref = {"channel": None}
537
-
538
- @pc.on("datachannel")
539
- def on_datachannel(channel):
540
- print(f"Data channel opened")
541
- channel_ref["channel"] = channel
542
-
543
- @pc.on("track")
544
- def on_track(track):
545
- print(f"Received track: {track.kind}")
546
- if track.kind == "video":
547
- local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
548
- pc.addTrack(local_track)
549
- print(f"Video track added")
550
-
551
- @track.on("ended")
552
- async def on_ended():
553
- print(f"Track ended")
554
-
555
- @pc.on("connectionstatechange")
556
- async def on_connectionstatechange():
557
- print(f"Connection state changed: {pc.connectionState}")
558
- if pc.connectionState in ("failed", "closed", "disconnected"):
559
- try:
560
- await end_session(session_id)
561
- except Exception as e:
562
- print(f"⚠Error ending session: {e}")
563
- pcs.discard(pc)
564
- await pc.close()
565
-
566
- await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
567
- print(f"Remote description set")
568
-
569
- answer = await pc.createAnswer()
570
- await pc.setLocalDescription(answer)
571
- print(f"Answer created")
572
-
573
- await _wait_for_ice_gathering(pc)
574
- print(f"ICE gathering complete")
575
-
576
- return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
577
-
578
- except Exception as e:
579
- print(f"WebRTC offer error: {e}")
580
- import traceback
581
- traceback.print_exc()
582
- raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
583
-
584
- # ================ WEBSOCKET ================
585
-
586
- @app.websocket("/ws/video")
587
- async def websocket_endpoint(websocket: WebSocket):
588
- await websocket.accept()
589
- session_id = None
590
- frame_count = 0
591
- running = True
592
- event_buffer = _EventBuffer(flush_interval=2.0)
593
-
594
- # Latest frame slot — only the most recent frame is kept, older ones are dropped.
595
- # Using a dict so nested functions can mutate without nonlocal issues.
596
- _slot = {"frame": None}
597
- _frame_ready = asyncio.Event()
598
-
599
- async def _receive_loop():
600
- """Receive messages as fast as possible. Binary = frame, text = control."""
601
- nonlocal session_id, running
602
- try:
603
- while running:
604
- msg = await websocket.receive()
605
- msg_type = msg.get("type", "")
606
-
607
- if msg_type == "websocket.disconnect":
608
- running = False
609
- _frame_ready.set()
610
- return
611
-
612
- # Binary message → JPEG frame (fast path, no base64)
613
- raw_bytes = msg.get("bytes")
614
- if raw_bytes is not None and len(raw_bytes) > 0:
615
- _slot["frame"] = raw_bytes
616
- _frame_ready.set()
617
- continue
618
-
619
- # Text message → JSON control command (or legacy base64 frame)
620
- text = msg.get("text")
621
- if not text:
622
- continue
623
- data = json.loads(text)
624
-
625
- if data["type"] == "frame":
626
- # Legacy base64 path (fallback)
627
- _slot["frame"] = base64.b64decode(data["image"])
628
- _frame_ready.set()
629
-
630
- elif data["type"] == "start_session":
631
- session_id = await create_session()
632
- event_buffer.start()
633
- for p in pipelines.values():
634
- if p is not None and hasattr(p, "reset_session"):
635
- p.reset_session()
636
- await websocket.send_json({"type": "session_started", "session_id": session_id})
637
-
638
- elif data["type"] == "end_session":
639
- if session_id:
640
- await event_buffer.stop()
641
- summary = await end_session(session_id)
642
- if summary:
643
- await websocket.send_json({"type": "session_ended", "summary": summary})
644
- session_id = None
645
- except WebSocketDisconnect:
646
- running = False
647
- _frame_ready.set()
648
- except Exception as e:
649
- print(f"[WS] receive error: {e}")
650
- running = False
651
- _frame_ready.set()
652
-
653
- async def _process_loop():
654
- """Process only the latest frame, dropping stale ones."""
655
- nonlocal frame_count, running
656
- loop = asyncio.get_event_loop()
657
- while running:
658
- await _frame_ready.wait()
659
- _frame_ready.clear()
660
- if not running:
661
- return
662
-
663
- # Grab latest frame and clear slot
664
- raw = _slot["frame"]
665
- _slot["frame"] = None
666
- if raw is None:
667
- continue
668
-
669
- try:
670
- nparr = np.frombuffer(raw, np.uint8)
671
- frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
672
- if frame is None:
673
- continue
674
- frame = cv2.resize(frame, (640, 480))
675
-
676
- model_name = _cached_model_name
677
- if model_name not in pipelines or pipelines.get(model_name) is None:
678
- model_name = "mlp"
679
- active_pipeline = pipelines.get(model_name)
680
-
681
- landmarks_list = None
682
- if active_pipeline is not None:
683
- out = await loop.run_in_executor(
684
- _inference_executor,
685
- _process_frame_safe,
686
- active_pipeline,
687
- frame,
688
- model_name,
689
- )
690
- is_focused = out["is_focused"]
691
- confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
692
-
693
- lm = out.get("landmarks")
694
- if lm is not None:
695
- # Send all 478 landmarks as flat array for tessellation drawing
696
- landmarks_list = [
697
- [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
698
- for i in range(lm.shape[0])
699
- ]
700
-
701
- if session_id:
702
- event_buffer.add(session_id, is_focused, confidence, {
703
- "s_face": out.get("s_face", 0.0),
704
- "s_eye": out.get("s_eye", 0.0),
705
- "mar": out.get("mar", 0.0),
706
- "model": model_name,
707
- })
708
- else:
709
- is_focused = False
710
- confidence = 0.0
711
-
712
- resp = {
713
- "type": "detection",
714
- "focused": is_focused,
715
- "confidence": round(confidence, 3),
716
- "detections": [],
717
- "model": model_name,
718
- "fc": frame_count,
719
- "frame_count": frame_count,
720
- }
721
- if active_pipeline is not None:
722
- # Send detailed metrics for HUD
723
- if out.get("yaw") is not None:
724
- resp["yaw"] = round(out["yaw"], 1)
725
- resp["pitch"] = round(out["pitch"], 1)
726
- resp["roll"] = round(out["roll"], 1)
727
- if out.get("mar") is not None:
728
- resp["mar"] = round(out["mar"], 3)
729
- resp["sf"] = round(out.get("s_face", 0), 3)
730
- resp["se"] = round(out.get("s_eye", 0), 3)
731
- if landmarks_list is not None:
732
- resp["lm"] = landmarks_list
733
- await websocket.send_json(resp)
734
- frame_count += 1
735
- except Exception as e:
736
- print(f"[WS] process error: {e}")
737
-
738
- try:
739
- await asyncio.gather(_receive_loop(), _process_loop())
740
- except Exception:
741
- pass
742
- finally:
743
- running = False
744
- if session_id:
745
- await event_buffer.stop()
746
- await end_session(session_id)
747
-
748
- # ================ API ENDPOINTS ================
749
-
750
- @app.post("/api/sessions/start")
751
- async def api_start_session():
752
- session_id = await create_session()
753
- return {"session_id": session_id}
754
-
755
- @app.post("/api/sessions/end")
756
- async def api_end_session(data: SessionEnd):
757
- summary = await end_session(data.session_id)
758
- if not summary: raise HTTPException(status_code=404, detail="Session not found")
759
- return summary
760
-
761
- @app.get("/api/sessions")
762
- async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
763
- async with aiosqlite.connect(db_path) as db:
764
- db.row_factory = aiosqlite.Row
765
-
766
- # NEW: If importing/exporting all, remove limit if special flag or high limit
767
- # For simplicity: if limit is -1, return all
768
- limit_clause = "LIMIT ? OFFSET ?"
769
- params = []
770
-
771
- base_query = "SELECT * FROM focus_sessions"
772
- where_clause = ""
773
-
774
- if filter == "today":
775
- date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
776
- where_clause = " WHERE start_time >= ?"
777
- params.append(date_filter.isoformat())
778
- elif filter == "week":
779
- date_filter = datetime.now() - timedelta(days=7)
780
- where_clause = " WHERE start_time >= ?"
781
- params.append(date_filter.isoformat())
782
- elif filter == "month":
783
- date_filter = datetime.now() - timedelta(days=30)
784
- where_clause = " WHERE start_time >= ?"
785
- params.append(date_filter.isoformat())
786
- elif filter == "all":
787
- # Just ensure we only get completed sessions or all sessions
788
- where_clause = " WHERE end_time IS NOT NULL"
789
-
790
- query = f"{base_query}{where_clause} ORDER BY start_time DESC"
791
-
792
- # Handle Limit for Exports
793
- if limit == -1:
794
- # No limit clause for export
795
- pass
796
- else:
797
- query += f" {limit_clause}"
798
- params.extend([limit, offset])
799
-
800
- cursor = await db.execute(query, tuple(params))
801
- rows = await cursor.fetchall()
802
- return [dict(row) for row in rows]
803
-
804
- # --- NEW: Import Endpoint ---
805
- @app.post("/api/import")
806
- async def import_sessions(sessions: List[dict]):
807
- count = 0
808
- try:
809
- async with aiosqlite.connect(db_path) as db:
810
- for session in sessions:
811
- # Use .get() to handle potential missing fields from older versions or edits
812
- await db.execute("""
813
- INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
814
- VALUES (?, ?, ?, ?, ?, ?, ?)
815
- """, (
816
- session.get('start_time'),
817
- session.get('end_time'),
818
- session.get('duration_seconds', 0),
819
- session.get('focus_score', 0.0),
820
- session.get('total_frames', 0),
821
- session.get('focused_frames', 0),
822
- session.get('created_at', session.get('start_time'))
823
- ))
824
- count += 1
825
- await db.commit()
826
- return {"status": "success", "count": count}
827
- except Exception as e:
828
- print(f"Import Error: {e}")
829
- return {"status": "error", "message": str(e)}
830
-
831
- # --- NEW: Clear History Endpoint ---
832
- @app.delete("/api/history")
833
- async def clear_history():
834
- try:
835
- async with aiosqlite.connect(db_path) as db:
836
- # Delete events first (foreign key good practice)
837
- await db.execute("DELETE FROM focus_events")
838
- await db.execute("DELETE FROM focus_sessions")
839
- await db.commit()
840
- return {"status": "success", "message": "History cleared"}
841
- except Exception as e:
842
- return {"status": "error", "message": str(e)}
843
-
844
- @app.get("/api/sessions/{session_id}")
845
- async def get_session(session_id: int):
846
- async with aiosqlite.connect(db_path) as db:
847
- db.row_factory = aiosqlite.Row
848
- cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
849
- row = await cursor.fetchone()
850
- if not row: raise HTTPException(status_code=404, detail="Session not found")
851
- session = dict(row)
852
- cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
853
- events = [dict(r) for r in await cursor.fetchall()]
854
- session['events'] = events
855
- return session
856
-
857
- @app.get("/api/settings")
858
- async def get_settings():
859
- async with aiosqlite.connect(db_path) as db:
860
- db.row_factory = aiosqlite.Row
861
- cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
862
- row = await cursor.fetchone()
863
- if row: return dict(row)
864
- else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
865
-
866
- @app.put("/api/settings")
867
- async def update_settings(settings: SettingsUpdate):
868
- async with aiosqlite.connect(db_path) as db:
869
- cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
870
- exists = await cursor.fetchone()
871
- if not exists:
872
- await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
873
- await db.commit()
874
-
875
- updates = []
876
- params = []
877
- if settings.sensitivity is not None:
878
- updates.append("sensitivity = ?")
879
- params.append(max(1, min(10, settings.sensitivity)))
880
- if settings.notification_enabled is not None:
881
- updates.append("notification_enabled = ?")
882
- params.append(settings.notification_enabled)
883
- if settings.notification_threshold is not None:
884
- updates.append("notification_threshold = ?")
885
- params.append(max(5, min(300, settings.notification_threshold)))
886
- if settings.frame_rate is not None:
887
- updates.append("frame_rate = ?")
888
- params.append(max(5, min(60, settings.frame_rate)))
889
- if settings.model_name is not None and settings.model_name in pipelines and pipelines[settings.model_name] is not None:
890
- updates.append("model_name = ?")
891
- params.append(settings.model_name)
892
- global _cached_model_name
893
- _cached_model_name = settings.model_name
894
-
895
- if updates:
896
- query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
897
- await db.execute(query, params)
898
- await db.commit()
899
- return {"status": "success", "updated": len(updates) > 0}
900
-
901
- @app.get("/api/stats/summary")
902
- async def get_stats_summary():
903
- async with aiosqlite.connect(db_path) as db:
904
- cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
905
- total_sessions = (await cursor.fetchone())[0]
906
- cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
907
- total_focus_time = (await cursor.fetchone())[0] or 0
908
- cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
909
- avg_focus_score = (await cursor.fetchone())[0] or 0.0
910
- cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
911
- dates = [row[0] for row in await cursor.fetchall()]
912
-
913
- streak_days = 0
914
- if dates:
915
- current_date = datetime.now().date()
916
- for i, date_str in enumerate(dates):
917
- session_date = datetime.fromisoformat(date_str).date()
918
- expected_date = current_date - timedelta(days=i)
919
- if session_date == expected_date: streak_days += 1
920
- else: break
921
- return {
922
- 'total_sessions': total_sessions,
923
- 'total_focus_time': int(total_focus_time),
924
- 'avg_focus_score': round(avg_focus_score, 3),
925
- 'streak_days': streak_days
926
- }
927
-
928
- @app.get("/api/models")
929
- async def get_available_models():
930
- """Return list of loaded model names and which is currently active."""
931
- available = [name for name, p in pipelines.items() if p is not None]
932
- async with aiosqlite.connect(db_path) as db:
933
- cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
934
- row = await cursor.fetchone()
935
- current = row[0] if row else "mlp"
936
- if current not in available and available:
937
- current = available[0]
938
- return {"available": available, "current": current}
939
-
940
- @app.get("/api/mesh-topology")
941
- async def get_mesh_topology():
942
- """Return tessellation edge pairs for client-side face mesh drawing (cached by client)."""
943
- return {"tessellation": _TESSELATION_CONNS}
944
-
945
- @app.get("/health")
946
- async def health_check():
947
- available = [name for name, p in pipelines.items() if p is not None]
948
- return {"status": "healthy", "models_loaded": available, "database": os.path.exists(db_path)}
949
-
950
- # ================ STATIC FILES (SPA SUPPORT) ================
951
-
952
- # Resolve frontend dir from this file so it works regardless of cwd.
953
- # Prefer a built `dist/` app when present, otherwise fall back to `static/`.
954
- _BASE_DIR = Path(__file__).resolve().parent
955
- _DIST_DIR = _BASE_DIR / "dist"
956
- _STATIC_DIR = _BASE_DIR / "static"
957
- _FRONTEND_DIR = _DIST_DIR if (_DIST_DIR / "index.html").is_file() else _STATIC_DIR
958
- _ASSETS_DIR = _FRONTEND_DIR / "assets"
959
-
960
- # 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
961
- if _ASSETS_DIR.is_dir():
962
- app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")
963
-
964
- # 2. Catch-all for SPA: serve index.html for app routes, never for /assets (would break JS MIME type)
965
- @app.get("/{full_path:path}")
966
- async def serve_react_app(full_path: str, request: Request):
967
- if full_path.startswith("api") or full_path.startswith("ws"):
968
- raise HTTPException(status_code=404, detail="Not Found")
969
- # Don't serve HTML for asset paths; let them 404 so we don't break module script loading
970
- if full_path.startswith("assets") or full_path.startswith("assets/"):
971
- raise HTTPException(status_code=404, detail="Not Found")
972
-
973
- file_path = _FRONTEND_DIR / full_path
974
- if full_path and file_path.is_file():
975
- return FileResponse(str(file_path))
976
-
977
- index_path = _FRONTEND_DIR / "index.html"
978
- if index_path.is_file():
979
- return FileResponse(str(index_path))
980
- return {"message": "React app not found. Please run 'npm run build' and copy dist to static if needed."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main_yolo.py DELETED
@@ -1,601 +0,0 @@
1
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
2
- from fastapi.staticfiles import StaticFiles
3
- from fastapi.responses import FileResponse
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from pydantic import BaseModel
6
- from typing import Optional, List, Any
7
- import base64
8
- import cv2
9
- import numpy as np
10
- import aiosqlite
11
- import json
12
- from datetime import datetime, timedelta
13
- import math
14
- import os
15
- from pathlib import Path
16
- from typing import Callable
17
- import asyncio
18
-
19
- from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
20
- from av import VideoFrame
21
-
22
- from ui.pipeline import MLPPipeline
23
-
24
- # Initialize FastAPI app
25
- app = FastAPI(title="Focus Guard API")
26
-
27
- # Add CORS middleware
28
- app.add_middleware(
29
- CORSMiddleware,
30
- allow_origins=["*"],
31
- allow_credentials=True,
32
- allow_methods=["*"],
33
- allow_headers=["*"],
34
- )
35
-
36
- # Global variables
37
- db_path = "focus_guard.db"
38
- pcs = set()
39
-
40
- async def _wait_for_ice_gathering(pc: RTCPeerConnection):
41
- if pc.iceGatheringState == "complete":
42
- return
43
- done = asyncio.Event()
44
-
45
- @pc.on("icegatheringstatechange")
46
- def _on_state_change():
47
- if pc.iceGatheringState == "complete":
48
- done.set()
49
-
50
- await done.wait()
51
-
52
- # ================ DATABASE MODELS ================
53
-
54
- async def init_database():
55
- """Initialize SQLite database with required tables"""
56
- async with aiosqlite.connect(db_path) as db:
57
- # FocusSessions table
58
- await db.execute("""
59
- CREATE TABLE IF NOT EXISTS focus_sessions (
60
- id INTEGER PRIMARY KEY AUTOINCREMENT,
61
- start_time TIMESTAMP NOT NULL,
62
- end_time TIMESTAMP,
63
- duration_seconds INTEGER DEFAULT 0,
64
- focus_score REAL DEFAULT 0.0,
65
- total_frames INTEGER DEFAULT 0,
66
- focused_frames INTEGER DEFAULT 0,
67
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
68
- )
69
- """)
70
-
71
- # FocusEvents table
72
- await db.execute("""
73
- CREATE TABLE IF NOT EXISTS focus_events (
74
- id INTEGER PRIMARY KEY AUTOINCREMENT,
75
- session_id INTEGER NOT NULL,
76
- timestamp TIMESTAMP NOT NULL,
77
- is_focused BOOLEAN NOT NULL,
78
- confidence REAL NOT NULL,
79
- detection_data TEXT,
80
- FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
81
- )
82
- """)
83
-
84
- # UserSettings table
85
- await db.execute("""
86
- CREATE TABLE IF NOT EXISTS user_settings (
87
- id INTEGER PRIMARY KEY CHECK (id = 1),
88
- sensitivity INTEGER DEFAULT 6,
89
- notification_enabled BOOLEAN DEFAULT 1,
90
- notification_threshold INTEGER DEFAULT 30,
91
- frame_rate INTEGER DEFAULT 30,
92
- model_name TEXT DEFAULT 'yolov8n.pt'
93
- )
94
- """)
95
-
96
- # Insert default settings if not exists
97
- await db.execute("""
98
- INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
99
- VALUES (1, 6, 1, 30, 30, 'yolov8n.pt')
100
- """)
101
-
102
- await db.commit()
103
-
104
- # ================ PYDANTIC MODELS ================
105
-
106
- class SessionCreate(BaseModel):
107
- pass
108
-
109
- class SessionEnd(BaseModel):
110
- session_id: int
111
-
112
- class SettingsUpdate(BaseModel):
113
- sensitivity: Optional[int] = None
114
- notification_enabled: Optional[bool] = None
115
- notification_threshold: Optional[int] = None
116
- frame_rate: Optional[int] = None
117
-
118
- class VideoTransformTrack(VideoStreamTrack):
119
- def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
120
- super().__init__()
121
- self.track = track
122
- self.session_id = session_id
123
- self.get_channel = get_channel
124
- self.last_inference_time = 0
125
- self.min_inference_interval = 1 / 60
126
- self.last_frame = None
127
-
128
- async def recv(self):
129
- frame = await self.track.recv()
130
- img = frame.to_ndarray(format="bgr24")
131
- if img is None:
132
- return frame
133
-
134
- # Normalize size for inference/drawing
135
- img = cv2.resize(img, (640, 480))
136
-
137
- now = datetime.now().timestamp()
138
- do_infer = (now - self.last_inference_time) >= self.min_inference_interval
139
-
140
- if do_infer and mlp_pipeline is not None:
141
- self.last_inference_time = now
142
- out = mlp_pipeline.process_frame(img)
143
- is_focused = out["is_focused"]
144
- confidence = out["mlp_prob"]
145
- metadata = {"s_face": out["s_face"], "s_eye": out["s_eye"], "mar": out["mar"]}
146
- detections = []
147
- status_text = "FOCUSED" if is_focused else "NOT FOCUSED"
148
- color = (0, 255, 0) if is_focused else (0, 0, 255)
149
- cv2.putText(img, status_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
150
- cv2.putText(img, f"Confidence: {confidence * 100:.1f}%", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
151
-
152
- if self.session_id:
153
- await store_focus_event(self.session_id, is_focused, confidence, metadata)
154
-
155
- channel = self.get_channel()
156
- if channel and channel.readyState == "open":
157
- try:
158
- channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": detections}))
159
- except Exception:
160
- pass
161
-
162
- self.last_frame = img
163
- elif self.last_frame is not None:
164
- img = self.last_frame
165
-
166
- new_frame = VideoFrame.from_ndarray(img, format="bgr24")
167
- new_frame.pts = frame.pts
168
- new_frame.time_base = frame.time_base
169
- return new_frame
170
-
171
- # ================ DATABASE OPERATIONS ================
172
-
173
- async def create_session():
174
- async with aiosqlite.connect(db_path) as db:
175
- cursor = await db.execute(
176
- "INSERT INTO focus_sessions (start_time) VALUES (?)",
177
- (datetime.now().isoformat(),)
178
- )
179
- await db.commit()
180
- return cursor.lastrowid
181
-
182
- async def end_session(session_id: int):
183
- async with aiosqlite.connect(db_path) as db:
184
- cursor = await db.execute(
185
- "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
186
- (session_id,)
187
- )
188
- row = await cursor.fetchone()
189
-
190
- if not row:
191
- return None
192
-
193
- start_time_str, total_frames, focused_frames = row
194
- start_time = datetime.fromisoformat(start_time_str)
195
- end_time = datetime.now()
196
- duration = (end_time - start_time).total_seconds()
197
- focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
198
-
199
- await db.execute("""
200
- UPDATE focus_sessions
201
- SET end_time = ?, duration_seconds = ?, focus_score = ?
202
- WHERE id = ?
203
- """, (end_time.isoformat(), int(duration), focus_score, session_id))
204
-
205
- await db.commit()
206
-
207
- return {
208
- 'session_id': session_id,
209
- 'start_time': start_time_str,
210
- 'end_time': end_time.isoformat(),
211
- 'duration_seconds': int(duration),
212
- 'focus_score': round(focus_score, 3),
213
- 'total_frames': total_frames,
214
- 'focused_frames': focused_frames
215
- }
216
-
217
- async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
218
- async with aiosqlite.connect(db_path) as db:
219
- await db.execute("""
220
- INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
221
- VALUES (?, ?, ?, ?, ?)
222
- """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
223
-
224
- await db.execute(f"""
225
- UPDATE focus_sessions
226
- SET total_frames = total_frames + 1,
227
- focused_frames = focused_frames + {1 if is_focused else 0}
228
- WHERE id = ?
229
- """, (session_id,))
230
- await db.commit()
231
-
232
- # ================ STARTUP/SHUTDOWN ================
233
-
234
- mlp_pipeline = None
235
-
236
- @app.on_event("startup")
237
- async def startup_event():
238
- global mlp_pipeline
239
- print(" Starting Focus Guard API...")
240
- await init_database()
241
- print("[OK] Database initialized")
242
-
243
- mlp_pipeline = MLPPipeline()
244
- print("[OK] MLPPipeline loaded")
245
-
246
- @app.on_event("shutdown")
247
- async def shutdown_event():
248
- print(" Shutting down Focus Guard API...")
249
-
250
- # ================ WEBRTC SIGNALING ================
251
-
252
- @app.post("/api/webrtc/offer")
253
- async def webrtc_offer(offer: dict):
254
- try:
255
- print(f"Received WebRTC offer")
256
-
257
- pc = RTCPeerConnection()
258
- pcs.add(pc)
259
-
260
- session_id = await create_session()
261
- print(f"Created session: {session_id}")
262
-
263
- channel_ref = {"channel": None}
264
-
265
- @pc.on("datachannel")
266
- def on_datachannel(channel):
267
- print(f"Data channel opened")
268
- channel_ref["channel"] = channel
269
-
270
- @pc.on("track")
271
- def on_track(track):
272
- print(f"Received track: {track.kind}")
273
- if track.kind == "video":
274
- local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
275
- pc.addTrack(local_track)
276
- print(f"Video track added")
277
-
278
- @track.on("ended")
279
- async def on_ended():
280
- print(f"Track ended")
281
-
282
- @pc.on("connectionstatechange")
283
- async def on_connectionstatechange():
284
- print(f"Connection state changed: {pc.connectionState}")
285
- if pc.connectionState in ("failed", "closed", "disconnected"):
286
- try:
287
- await end_session(session_id)
288
- except Exception as e:
289
- print(f"⚠Error ending session: {e}")
290
- pcs.discard(pc)
291
- await pc.close()
292
-
293
- await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
294
- print(f"Remote description set")
295
-
296
- answer = await pc.createAnswer()
297
- await pc.setLocalDescription(answer)
298
- print(f"Answer created")
299
-
300
- await _wait_for_ice_gathering(pc)
301
- print(f"ICE gathering complete")
302
-
303
- return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
304
-
305
- except Exception as e:
306
- print(f"WebRTC offer error: {e}")
307
- import traceback
308
- traceback.print_exc()
309
- raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
310
-
311
- # ================ WEBSOCKET ================
312
-
313
- @app.websocket("/ws/video")
314
- async def websocket_endpoint(websocket: WebSocket):
315
- await websocket.accept()
316
- session_id = None
317
- frame_count = 0
318
- last_inference_time = 0
319
- min_inference_interval = 1 / 60
320
-
321
- try:
322
- async with aiosqlite.connect(db_path) as db:
323
- cursor = await db.execute("SELECT sensitivity FROM user_settings WHERE id = 1")
324
- row = await cursor.fetchone()
325
- sensitivity = row[0] if row else 6
326
-
327
- while True:
328
- data = await websocket.receive_json()
329
-
330
- if data['type'] == 'frame':
331
- from time import time
332
- current_time = time()
333
- if current_time - last_inference_time < min_inference_interval:
334
- await websocket.send_json({'type': 'ack', 'frame_count': frame_count})
335
- continue
336
- last_inference_time = current_time
337
-
338
- try:
339
- img_data = base64.b64decode(data['image'])
340
- nparr = np.frombuffer(img_data, np.uint8)
341
- frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
342
-
343
- if frame is None: continue
344
- frame = cv2.resize(frame, (640, 480))
345
-
346
- if mlp_pipeline is not None:
347
- out = mlp_pipeline.process_frame(frame)
348
-
349
- is_focused = out["is_focused"]
350
- confidence = out["mlp_prob"]
351
- metadata = {
352
- "s_face": out["s_face"],
353
- "s_eye": out["s_eye"],
354
- "mar": out["mar"]
355
- }
356
- else:
357
- is_focused = False
358
- confidence = 0.0
359
- metadata = {}
360
-
361
- detections = []
362
-
363
- if session_id:
364
- await store_focus_event(session_id, is_focused, confidence, metadata)
365
-
366
- await websocket.send_json({
367
- 'type': 'detection',
368
- 'focused': is_focused,
369
- 'confidence': round(confidence, 3),
370
- 'detections': detections,
371
- 'frame_count': frame_count
372
- })
373
- frame_count += 1
374
- except Exception as e:
375
- print(f"Error processing frame: {e}")
376
- await websocket.send_json({'type': 'error', 'message': str(e)})
377
-
378
- elif data['type'] == 'start_session':
379
- session_id = await create_session()
380
- await websocket.send_json({'type': 'session_started', 'session_id': session_id})
381
-
382
- elif data['type'] == 'end_session':
383
- if session_id:
384
- print(f"Ending session {session_id}...")
385
- summary = await end_session(session_id)
386
- print(f"Session summary: {summary}")
387
- if summary:
388
- await websocket.send_json({'type': 'session_ended', 'summary': summary})
389
- print("Session ended message sent")
390
- else:
391
- print("Warning: No summary returned")
392
- session_id = None
393
- else:
394
- print("Warning: end_session called but no active session_id")
395
-
396
- except WebSocketDisconnect:
397
- if session_id: await end_session(session_id)
398
- except Exception as e:
399
- if websocket.client_state.value == 1: await websocket.close()
400
-
401
- # ================ API ENDPOINTS ================
402
-
403
- @app.post("/api/sessions/start")
404
- async def api_start_session():
405
- session_id = await create_session()
406
- return {"session_id": session_id}
407
-
408
- @app.post("/api/sessions/end")
409
- async def api_end_session(data: SessionEnd):
410
- summary = await end_session(data.session_id)
411
- if not summary: raise HTTPException(status_code=404, detail="Session not found")
412
- return summary
413
-
414
- @app.get("/api/sessions")
415
- async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
416
- async with aiosqlite.connect(db_path) as db:
417
- db.row_factory = aiosqlite.Row
418
-
419
- # NEW: If importing/exporting all, remove limit if special flag or high limit
420
- # For simplicity: if limit is -1, return all
421
- limit_clause = "LIMIT ? OFFSET ?"
422
- params = []
423
-
424
- base_query = "SELECT * FROM focus_sessions"
425
- where_clause = ""
426
-
427
- if filter == "today":
428
- date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
429
- where_clause = " WHERE start_time >= ?"
430
- params.append(date_filter.isoformat())
431
- elif filter == "week":
432
- date_filter = datetime.now() - timedelta(days=7)
433
- where_clause = " WHERE start_time >= ?"
434
- params.append(date_filter.isoformat())
435
- elif filter == "month":
436
- date_filter = datetime.now() - timedelta(days=30)
437
- where_clause = " WHERE start_time >= ?"
438
- params.append(date_filter.isoformat())
439
- elif filter == "all":
440
- # Just ensure we only get completed sessions or all sessions
441
- where_clause = " WHERE end_time IS NOT NULL"
442
-
443
- query = f"{base_query}{where_clause} ORDER BY start_time DESC"
444
-
445
- # Handle Limit for Exports
446
- if limit == -1:
447
- # No limit clause for export
448
- pass
449
- else:
450
- query += f" {limit_clause}"
451
- params.extend([limit, offset])
452
-
453
- cursor = await db.execute(query, tuple(params))
454
- rows = await cursor.fetchall()
455
- return [dict(row) for row in rows]
456
-
457
- # --- NEW: Import Endpoint ---
458
- @app.post("/api/import")
459
- async def import_sessions(sessions: List[dict]):
460
- count = 0
461
- try:
462
- async with aiosqlite.connect(db_path) as db:
463
- for session in sessions:
464
- # Use .get() to handle potential missing fields from older versions or edits
465
- await db.execute("""
466
- INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
467
- VALUES (?, ?, ?, ?, ?, ?, ?)
468
- """, (
469
- session.get('start_time'),
470
- session.get('end_time'),
471
- session.get('duration_seconds', 0),
472
- session.get('focus_score', 0.0),
473
- session.get('total_frames', 0),
474
- session.get('focused_frames', 0),
475
- session.get('created_at', session.get('start_time'))
476
- ))
477
- count += 1
478
- await db.commit()
479
- return {"status": "success", "count": count}
480
- except Exception as e:
481
- print(f"Import Error: {e}")
482
- return {"status": "error", "message": str(e)}
483
-
484
- # --- NEW: Clear History Endpoint ---
485
- @app.delete("/api/history")
486
- async def clear_history():
487
- try:
488
- async with aiosqlite.connect(db_path) as db:
489
- # Delete events first (foreign key good practice)
490
- await db.execute("DELETE FROM focus_events")
491
- await db.execute("DELETE FROM focus_sessions")
492
- await db.commit()
493
- return {"status": "success", "message": "History cleared"}
494
- except Exception as e:
495
- return {"status": "error", "message": str(e)}
496
-
497
- @app.get("/api/sessions/{session_id}")
498
- async def get_session(session_id: int):
499
- async with aiosqlite.connect(db_path) as db:
500
- db.row_factory = aiosqlite.Row
501
- cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
502
- row = await cursor.fetchone()
503
- if not row: raise HTTPException(status_code=404, detail="Session not found")
504
- session = dict(row)
505
- cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
506
- events = [dict(r) for r in await cursor.fetchall()]
507
- session['events'] = events
508
- return session
509
-
510
- @app.get("/api/settings")
511
- async def get_settings():
512
- async with aiosqlite.connect(db_path) as db:
513
- db.row_factory = aiosqlite.Row
514
- cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
515
- row = await cursor.fetchone()
516
- if row: return dict(row)
517
- else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'yolov8n.pt'}
518
-
519
- @app.put("/api/settings")
520
- async def update_settings(settings: SettingsUpdate):
521
- async with aiosqlite.connect(db_path) as db:
522
- cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
523
- exists = await cursor.fetchone()
524
- if not exists:
525
- await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
526
- await db.commit()
527
-
528
- updates = []
529
- params = []
530
- if settings.sensitivity is not None:
531
- updates.append("sensitivity = ?")
532
- params.append(max(1, min(10, settings.sensitivity)))
533
- if settings.notification_enabled is not None:
534
- updates.append("notification_enabled = ?")
535
- params.append(settings.notification_enabled)
536
- if settings.notification_threshold is not None:
537
- updates.append("notification_threshold = ?")
538
- params.append(max(5, min(300, settings.notification_threshold)))
539
- if settings.frame_rate is not None:
540
- updates.append("frame_rate = ?")
541
- params.append(max(5, min(60, settings.frame_rate)))
542
-
543
- if updates:
544
- query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
545
- await db.execute(query, params)
546
- await db.commit()
547
- return {"status": "success", "updated": len(updates) > 0}
548
-
549
- @app.get("/api/stats/summary")
550
- async def get_stats_summary():
551
- async with aiosqlite.connect(db_path) as db:
552
- cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
553
- total_sessions = (await cursor.fetchone())[0]
554
- cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
555
- total_focus_time = (await cursor.fetchone())[0] or 0
556
- cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
557
- avg_focus_score = (await cursor.fetchone())[0] or 0.0
558
- cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
559
- dates = [row[0] for row in await cursor.fetchall()]
560
-
561
- streak_days = 0
562
- if dates:
563
- current_date = datetime.now().date()
564
- for i, date_str in enumerate(dates):
565
- session_date = datetime.fromisoformat(date_str).date()
566
- expected_date = current_date - timedelta(days=i)
567
- if session_date == expected_date: streak_days += 1
568
- else: break
569
- return {
570
- 'total_sessions': total_sessions,
571
- 'total_focus_time': int(total_focus_time),
572
- 'avg_focus_score': round(avg_focus_score, 3),
573
- 'streak_days': streak_days
574
- }
575
-
576
- @app.get("/health")
577
- async def health_check():
578
- return {"status": "healthy", "model_loaded": mlp_pipeline is not None, "database": os.path.exists(db_path)}
579
-
580
- # ================ STATIC FILES (SPA SUPPORT) ================
581
-
582
- FRONTEND_DIR = "dist" if os.path.exists("dist/index.html") else "static"
583
-
584
- assets_path = os.path.join(FRONTEND_DIR, "assets")
585
- if os.path.exists(assets_path):
586
- app.mount("/assets", StaticFiles(directory=assets_path), name="assets")
587
-
588
- @app.get("/{full_path:path}")
589
- async def serve_react_app(full_path: str, request: Request):
590
- if full_path.startswith("api") or full_path.startswith("ws"):
591
- raise HTTPException(status_code=404, detail="Not Found")
592
-
593
- file_path = os.path.join(FRONTEND_DIR, full_path)
594
- if os.path.isfile(file_path):
595
- return FileResponse(file_path)
596
-
597
- index_path = os.path.join(FRONTEND_DIR, "index.html")
598
- if os.path.exists(index_path):
599
- return FileResponse(index_path)
600
- else:
601
- return {"message": "React app not found. Please run npm run build."}