Spaces:
Sleeping
Sleeping
Commit ·
9a472c2
1
Parent(s): 6941e5d
Combine all updated main code together, forming new main.py
Browse files- main.py +22 -6
- main_combined.py +0 -980
- main_yolo.py +0 -601
main.py
CHANGED
|
@@ -306,7 +306,13 @@ class VideoTransformTrack(VideoStreamTrack):
|
|
| 306 |
channel = self.get_channel()
|
| 307 |
if channel and channel.readyState == "open":
|
| 308 |
try:
|
| 309 |
-
channel.send(json.dumps({
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
except Exception:
|
| 311 |
pass
|
| 312 |
|
|
@@ -707,8 +713,10 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 707 |
"type": "detection",
|
| 708 |
"focused": is_focused,
|
| 709 |
"confidence": round(confidence, 3),
|
|
|
|
| 710 |
"model": model_name,
|
| 711 |
"fc": frame_count,
|
|
|
|
| 712 |
}
|
| 713 |
if active_pipeline is not None:
|
| 714 |
# Send detailed metrics for HUD
|
|
@@ -941,9 +949,13 @@ async def health_check():
|
|
| 941 |
|
| 942 |
# ================ STATIC FILES (SPA SUPPORT) ================
|
| 943 |
|
| 944 |
-
# Resolve
|
| 945 |
-
|
| 946 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 947 |
|
| 948 |
# 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
|
| 949 |
if _ASSETS_DIR.is_dir():
|
|
@@ -958,7 +970,11 @@ async def serve_react_app(full_path: str, request: Request):
|
|
| 958 |
if full_path.startswith("assets") or full_path.startswith("assets/"):
|
| 959 |
raise HTTPException(status_code=404, detail="Not Found")
|
| 960 |
|
| 961 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 962 |
if index_path.is_file():
|
| 963 |
return FileResponse(str(index_path))
|
| 964 |
-
return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}
|
|
|
|
| 306 |
channel = self.get_channel()
|
| 307 |
if channel and channel.readyState == "open":
|
| 308 |
try:
|
| 309 |
+
channel.send(json.dumps({
|
| 310 |
+
"type": "detection",
|
| 311 |
+
"focused": is_focused,
|
| 312 |
+
"confidence": round(confidence, 3),
|
| 313 |
+
"detections": [],
|
| 314 |
+
"model": model_name,
|
| 315 |
+
}))
|
| 316 |
except Exception:
|
| 317 |
pass
|
| 318 |
|
|
|
|
| 713 |
"type": "detection",
|
| 714 |
"focused": is_focused,
|
| 715 |
"confidence": round(confidence, 3),
|
| 716 |
+
"detections": [],
|
| 717 |
"model": model_name,
|
| 718 |
"fc": frame_count,
|
| 719 |
+
"frame_count": frame_count,
|
| 720 |
}
|
| 721 |
if active_pipeline is not None:
|
| 722 |
# Send detailed metrics for HUD
|
|
|
|
| 949 |
|
| 950 |
# ================ STATIC FILES (SPA SUPPORT) ================
|
| 951 |
|
| 952 |
+
# Resolve frontend dir from this file so it works regardless of cwd.
|
| 953 |
+
# Prefer a built `dist/` app when present, otherwise fall back to `static/`.
|
| 954 |
+
_BASE_DIR = Path(__file__).resolve().parent
|
| 955 |
+
_DIST_DIR = _BASE_DIR / "dist"
|
| 956 |
+
_STATIC_DIR = _BASE_DIR / "static"
|
| 957 |
+
_FRONTEND_DIR = _DIST_DIR if (_DIST_DIR / "index.html").is_file() else _STATIC_DIR
|
| 958 |
+
_ASSETS_DIR = _FRONTEND_DIR / "assets"
|
| 959 |
|
| 960 |
# 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
|
| 961 |
if _ASSETS_DIR.is_dir():
|
|
|
|
| 970 |
if full_path.startswith("assets") or full_path.startswith("assets/"):
|
| 971 |
raise HTTPException(status_code=404, detail="Not Found")
|
| 972 |
|
| 973 |
+
file_path = _FRONTEND_DIR / full_path
|
| 974 |
+
if full_path and file_path.is_file():
|
| 975 |
+
return FileResponse(str(file_path))
|
| 976 |
+
|
| 977 |
+
index_path = _FRONTEND_DIR / "index.html"
|
| 978 |
if index_path.is_file():
|
| 979 |
return FileResponse(str(index_path))
|
| 980 |
+
return {"message": "React app not found. Please run 'npm run build' and copy dist to static if needed."}
|
main_combined.py
DELETED
|
@@ -1,980 +0,0 @@
|
|
| 1 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
|
| 2 |
-
from fastapi.staticfiles import StaticFiles
|
| 3 |
-
from fastapi.responses import FileResponse
|
| 4 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
-
from pydantic import BaseModel
|
| 6 |
-
from typing import Optional, List, Any
|
| 7 |
-
import base64
|
| 8 |
-
import cv2
|
| 9 |
-
import numpy as np
|
| 10 |
-
import aiosqlite
|
| 11 |
-
import json
|
| 12 |
-
from datetime import datetime, timedelta
|
| 13 |
-
import math
|
| 14 |
-
import os
|
| 15 |
-
from pathlib import Path
|
| 16 |
-
from typing import Callable
|
| 17 |
-
import asyncio
|
| 18 |
-
import concurrent.futures
|
| 19 |
-
import threading
|
| 20 |
-
|
| 21 |
-
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
|
| 22 |
-
from av import VideoFrame
|
| 23 |
-
|
| 24 |
-
from mediapipe.tasks.python.vision import FaceLandmarksConnections
|
| 25 |
-
from ui.pipeline import FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline
|
| 26 |
-
from models.face_mesh import FaceMeshDetector
|
| 27 |
-
|
| 28 |
-
# ================ FACE MESH DRAWING (server-side, for WebRTC) ================
|
| 29 |
-
|
| 30 |
-
_FONT = cv2.FONT_HERSHEY_SIMPLEX
|
| 31 |
-
_CYAN = (255, 255, 0)
|
| 32 |
-
_GREEN = (0, 255, 0)
|
| 33 |
-
_MAGENTA = (255, 0, 255)
|
| 34 |
-
_ORANGE = (0, 165, 255)
|
| 35 |
-
_RED = (0, 0, 255)
|
| 36 |
-
_WHITE = (255, 255, 255)
|
| 37 |
-
_LIGHT_GREEN = (144, 238, 144)
|
| 38 |
-
|
| 39 |
-
_TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
|
| 40 |
-
_CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
|
| 41 |
-
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
|
| 42 |
-
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
|
| 43 |
-
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
|
| 44 |
-
_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
|
| 45 |
-
_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
|
| 46 |
-
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
|
| 47 |
-
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def _lm_px(lm, idx, w, h):
|
| 51 |
-
return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def _draw_polyline(frame, lm, indices, w, h, color, thickness):
|
| 55 |
-
for i in range(len(indices) - 1):
|
| 56 |
-
cv2.line(frame, _lm_px(lm, indices[i], w, h), _lm_px(lm, indices[i + 1], w, h), color, thickness, cv2.LINE_AA)
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def _draw_face_mesh(frame, lm, w, h):
|
| 60 |
-
"""Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines."""
|
| 61 |
-
# Tessellation (gray triangular grid, semi-transparent)
|
| 62 |
-
overlay = frame.copy()
|
| 63 |
-
for s, e in _TESSELATION_CONNS:
|
| 64 |
-
cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
|
| 65 |
-
cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
|
| 66 |
-
# Contours
|
| 67 |
-
for s, e in _CONTOUR_CONNS:
|
| 68 |
-
cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
|
| 69 |
-
# Eyebrows
|
| 70 |
-
_draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
|
| 71 |
-
_draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
|
| 72 |
-
# Nose
|
| 73 |
-
_draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
|
| 74 |
-
# Lips
|
| 75 |
-
_draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
|
| 76 |
-
_draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
|
| 77 |
-
# Eyes
|
| 78 |
-
left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
|
| 79 |
-
cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
|
| 80 |
-
right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
|
| 81 |
-
cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
|
| 82 |
-
# EAR key points
|
| 83 |
-
for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
|
| 84 |
-
for idx in indices:
|
| 85 |
-
cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
|
| 86 |
-
# Irises + gaze lines
|
| 87 |
-
for iris_idx, eye_inner, eye_outer in [
|
| 88 |
-
(FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
|
| 89 |
-
(FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
|
| 90 |
-
]:
|
| 91 |
-
iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
|
| 92 |
-
center = iris_pts[0]
|
| 93 |
-
if len(iris_pts) >= 5:
|
| 94 |
-
radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
|
| 95 |
-
radius = max(int(np.mean(radii)), 2)
|
| 96 |
-
cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
|
| 97 |
-
cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
|
| 98 |
-
eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
|
| 99 |
-
eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
|
| 100 |
-
dx, dy = center[0] - eye_cx, center[1] - eye_cy
|
| 101 |
-
cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def _draw_hud(frame, result, model_name):
|
| 105 |
-
"""Draw status bar and detail overlay matching live_demo.py."""
|
| 106 |
-
h, w = frame.shape[:2]
|
| 107 |
-
is_focused = result["is_focused"]
|
| 108 |
-
status = "FOCUSED" if is_focused else "NOT FOCUSED"
|
| 109 |
-
color = _GREEN if is_focused else _RED
|
| 110 |
-
|
| 111 |
-
# Top bar
|
| 112 |
-
cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
|
| 113 |
-
cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
|
| 114 |
-
cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)
|
| 115 |
-
|
| 116 |
-
# Detail line
|
| 117 |
-
conf = result.get("mlp_prob", result.get("raw_score", 0.0))
|
| 118 |
-
mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
|
| 119 |
-
sf = result.get("s_face", 0)
|
| 120 |
-
se = result.get("s_eye", 0)
|
| 121 |
-
detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
|
| 122 |
-
cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)
|
| 123 |
-
|
| 124 |
-
# Head pose (top right)
|
| 125 |
-
if result.get("yaw") is not None:
|
| 126 |
-
cv2.putText(frame, f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
|
| 127 |
-
(w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
|
| 128 |
-
|
| 129 |
-
# Yawn indicator
|
| 130 |
-
if result.get("is_yawning"):
|
| 131 |
-
cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
|
| 132 |
-
|
| 133 |
-
# Landmark indices used for face mesh drawing on client (union of all groups).
|
| 134 |
-
# Sending only these instead of all 478 saves ~60% of the landmarks payload.
|
| 135 |
-
_MESH_INDICES = sorted(set(
|
| 136 |
-
[10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] # face oval
|
| 137 |
-
+ [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246] # left eye
|
| 138 |
-
+ [362,382,381,380,374,373,390,249,263,466,388,387,386,385,384,398] # right eye
|
| 139 |
-
+ [468,469,470,471,472, 473,474,475,476,477] # irises
|
| 140 |
-
+ [70,63,105,66,107,55,65,52,53,46] # left eyebrow
|
| 141 |
-
+ [300,293,334,296,336,285,295,282,283,276] # right eyebrow
|
| 142 |
-
+ [6,197,195,5,4,1,19,94,2] # nose bridge
|
| 143 |
-
+ [61,146,91,181,84,17,314,405,321,375,291,409,270,269,267,0,37,39,40,185] # lips outer
|
| 144 |
-
+ [78,95,88,178,87,14,317,402,318,324,308,415,310,311,312,13,82,81,80,191] # lips inner
|
| 145 |
-
+ [33,160,158,133,153,145] # left EAR key points
|
| 146 |
-
+ [362,385,387,263,373,380] # right EAR key points
|
| 147 |
-
))
|
| 148 |
-
# Build a lookup: original_index -> position in sparse array, so client can reconstruct.
|
| 149 |
-
_MESH_INDEX_SET = set(_MESH_INDICES)
|
| 150 |
-
|
| 151 |
-
# Initialize FastAPI app
|
| 152 |
-
app = FastAPI(title="Focus Guard API")
|
| 153 |
-
|
| 154 |
-
# Add CORS middleware
|
| 155 |
-
app.add_middleware(
|
| 156 |
-
CORSMiddleware,
|
| 157 |
-
allow_origins=["*"],
|
| 158 |
-
allow_credentials=True,
|
| 159 |
-
allow_methods=["*"],
|
| 160 |
-
allow_headers=["*"],
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
# Global variables
|
| 164 |
-
db_path = "focus_guard.db"
|
| 165 |
-
pcs = set()
|
| 166 |
-
_cached_model_name = "mlp" # in-memory cache, updated via /api/settings
|
| 167 |
-
|
| 168 |
-
async def _wait_for_ice_gathering(pc: RTCPeerConnection):
|
| 169 |
-
if pc.iceGatheringState == "complete":
|
| 170 |
-
return
|
| 171 |
-
done = asyncio.Event()
|
| 172 |
-
|
| 173 |
-
@pc.on("icegatheringstatechange")
|
| 174 |
-
def _on_state_change():
|
| 175 |
-
if pc.iceGatheringState == "complete":
|
| 176 |
-
done.set()
|
| 177 |
-
|
| 178 |
-
await done.wait()
|
| 179 |
-
|
| 180 |
-
# ================ DATABASE MODELS ================
|
| 181 |
-
|
| 182 |
-
async def init_database():
|
| 183 |
-
"""Initialize SQLite database with required tables"""
|
| 184 |
-
async with aiosqlite.connect(db_path) as db:
|
| 185 |
-
# FocusSessions table
|
| 186 |
-
await db.execute("""
|
| 187 |
-
CREATE TABLE IF NOT EXISTS focus_sessions (
|
| 188 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 189 |
-
start_time TIMESTAMP NOT NULL,
|
| 190 |
-
end_time TIMESTAMP,
|
| 191 |
-
duration_seconds INTEGER DEFAULT 0,
|
| 192 |
-
focus_score REAL DEFAULT 0.0,
|
| 193 |
-
total_frames INTEGER DEFAULT 0,
|
| 194 |
-
focused_frames INTEGER DEFAULT 0,
|
| 195 |
-
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 196 |
-
)
|
| 197 |
-
""")
|
| 198 |
-
|
| 199 |
-
# FocusEvents table
|
| 200 |
-
await db.execute("""
|
| 201 |
-
CREATE TABLE IF NOT EXISTS focus_events (
|
| 202 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 203 |
-
session_id INTEGER NOT NULL,
|
| 204 |
-
timestamp TIMESTAMP NOT NULL,
|
| 205 |
-
is_focused BOOLEAN NOT NULL,
|
| 206 |
-
confidence REAL NOT NULL,
|
| 207 |
-
detection_data TEXT,
|
| 208 |
-
FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
|
| 209 |
-
)
|
| 210 |
-
""")
|
| 211 |
-
|
| 212 |
-
# UserSettings table
|
| 213 |
-
await db.execute("""
|
| 214 |
-
CREATE TABLE IF NOT EXISTS user_settings (
|
| 215 |
-
id INTEGER PRIMARY KEY CHECK (id = 1),
|
| 216 |
-
sensitivity INTEGER DEFAULT 6,
|
| 217 |
-
notification_enabled BOOLEAN DEFAULT 1,
|
| 218 |
-
notification_threshold INTEGER DEFAULT 30,
|
| 219 |
-
frame_rate INTEGER DEFAULT 30,
|
| 220 |
-
model_name TEXT DEFAULT 'mlp'
|
| 221 |
-
)
|
| 222 |
-
""")
|
| 223 |
-
|
| 224 |
-
# Insert default settings if not exists
|
| 225 |
-
await db.execute("""
|
| 226 |
-
INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
|
| 227 |
-
VALUES (1, 6, 1, 30, 30, 'mlp')
|
| 228 |
-
""")
|
| 229 |
-
|
| 230 |
-
await db.commit()
|
| 231 |
-
|
| 232 |
-
# ================ PYDANTIC MODELS ================
|
| 233 |
-
|
| 234 |
-
class SessionCreate(BaseModel):
|
| 235 |
-
pass
|
| 236 |
-
|
| 237 |
-
class SessionEnd(BaseModel):
|
| 238 |
-
session_id: int
|
| 239 |
-
|
| 240 |
-
class SettingsUpdate(BaseModel):
|
| 241 |
-
sensitivity: Optional[int] = None
|
| 242 |
-
notification_enabled: Optional[bool] = None
|
| 243 |
-
notification_threshold: Optional[int] = None
|
| 244 |
-
frame_rate: Optional[int] = None
|
| 245 |
-
model_name: Optional[str] = None
|
| 246 |
-
|
| 247 |
-
class VideoTransformTrack(VideoStreamTrack):
|
| 248 |
-
def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
|
| 249 |
-
super().__init__()
|
| 250 |
-
self.track = track
|
| 251 |
-
self.session_id = session_id
|
| 252 |
-
self.get_channel = get_channel
|
| 253 |
-
self.last_inference_time = 0
|
| 254 |
-
self.min_inference_interval = 1 / 60
|
| 255 |
-
self.last_frame = None
|
| 256 |
-
|
| 257 |
-
async def recv(self):
|
| 258 |
-
frame = await self.track.recv()
|
| 259 |
-
img = frame.to_ndarray(format="bgr24")
|
| 260 |
-
if img is None:
|
| 261 |
-
return frame
|
| 262 |
-
|
| 263 |
-
# Normalize size for inference/drawing
|
| 264 |
-
img = cv2.resize(img, (640, 480))
|
| 265 |
-
|
| 266 |
-
now = datetime.now().timestamp()
|
| 267 |
-
do_infer = (now - self.last_inference_time) >= self.min_inference_interval
|
| 268 |
-
|
| 269 |
-
if do_infer:
|
| 270 |
-
self.last_inference_time = now
|
| 271 |
-
|
| 272 |
-
model_name = _cached_model_name
|
| 273 |
-
if model_name not in pipelines or pipelines.get(model_name) is None:
|
| 274 |
-
model_name = 'mlp'
|
| 275 |
-
active_pipeline = pipelines.get(model_name)
|
| 276 |
-
|
| 277 |
-
if active_pipeline is not None:
|
| 278 |
-
loop = asyncio.get_event_loop()
|
| 279 |
-
out = await loop.run_in_executor(
|
| 280 |
-
_inference_executor,
|
| 281 |
-
_process_frame_safe,
|
| 282 |
-
active_pipeline,
|
| 283 |
-
img,
|
| 284 |
-
model_name,
|
| 285 |
-
)
|
| 286 |
-
is_focused = out["is_focused"]
|
| 287 |
-
confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
|
| 288 |
-
metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}
|
| 289 |
-
|
| 290 |
-
# Draw face mesh + HUD on the video frame
|
| 291 |
-
h_f, w_f = img.shape[:2]
|
| 292 |
-
lm = out.get("landmarks")
|
| 293 |
-
if lm is not None:
|
| 294 |
-
_draw_face_mesh(img, lm, w_f, h_f)
|
| 295 |
-
_draw_hud(img, out, model_name)
|
| 296 |
-
else:
|
| 297 |
-
is_focused = False
|
| 298 |
-
confidence = 0.0
|
| 299 |
-
metadata = {"model": model_name}
|
| 300 |
-
cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
|
| 301 |
-
cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)
|
| 302 |
-
|
| 303 |
-
if self.session_id:
|
| 304 |
-
await store_focus_event(self.session_id, is_focused, confidence, metadata)
|
| 305 |
-
|
| 306 |
-
channel = self.get_channel()
|
| 307 |
-
if channel and channel.readyState == "open":
|
| 308 |
-
try:
|
| 309 |
-
channel.send(json.dumps({
|
| 310 |
-
"type": "detection",
|
| 311 |
-
"focused": is_focused,
|
| 312 |
-
"confidence": round(confidence, 3),
|
| 313 |
-
"detections": [],
|
| 314 |
-
"model": model_name,
|
| 315 |
-
}))
|
| 316 |
-
except Exception:
|
| 317 |
-
pass
|
| 318 |
-
|
| 319 |
-
self.last_frame = img
|
| 320 |
-
elif self.last_frame is not None:
|
| 321 |
-
img = self.last_frame
|
| 322 |
-
|
| 323 |
-
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
|
| 324 |
-
new_frame.pts = frame.pts
|
| 325 |
-
new_frame.time_base = frame.time_base
|
| 326 |
-
return new_frame
|
| 327 |
-
|
| 328 |
-
# ================ DATABASE OPERATIONS ================
|
| 329 |
-
|
| 330 |
-
async def create_session():
|
| 331 |
-
async with aiosqlite.connect(db_path) as db:
|
| 332 |
-
cursor = await db.execute(
|
| 333 |
-
"INSERT INTO focus_sessions (start_time) VALUES (?)",
|
| 334 |
-
(datetime.now().isoformat(),)
|
| 335 |
-
)
|
| 336 |
-
await db.commit()
|
| 337 |
-
return cursor.lastrowid
|
| 338 |
-
|
| 339 |
-
async def end_session(session_id: int):
|
| 340 |
-
async with aiosqlite.connect(db_path) as db:
|
| 341 |
-
cursor = await db.execute(
|
| 342 |
-
"SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
|
| 343 |
-
(session_id,)
|
| 344 |
-
)
|
| 345 |
-
row = await cursor.fetchone()
|
| 346 |
-
|
| 347 |
-
if not row:
|
| 348 |
-
return None
|
| 349 |
-
|
| 350 |
-
start_time_str, total_frames, focused_frames = row
|
| 351 |
-
start_time = datetime.fromisoformat(start_time_str)
|
| 352 |
-
end_time = datetime.now()
|
| 353 |
-
duration = (end_time - start_time).total_seconds()
|
| 354 |
-
focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
|
| 355 |
-
|
| 356 |
-
await db.execute("""
|
| 357 |
-
UPDATE focus_sessions
|
| 358 |
-
SET end_time = ?, duration_seconds = ?, focus_score = ?
|
| 359 |
-
WHERE id = ?
|
| 360 |
-
""", (end_time.isoformat(), int(duration), focus_score, session_id))
|
| 361 |
-
|
| 362 |
-
await db.commit()
|
| 363 |
-
|
| 364 |
-
return {
|
| 365 |
-
'session_id': session_id,
|
| 366 |
-
'start_time': start_time_str,
|
| 367 |
-
'end_time': end_time.isoformat(),
|
| 368 |
-
'duration_seconds': int(duration),
|
| 369 |
-
'focus_score': round(focus_score, 3),
|
| 370 |
-
'total_frames': total_frames,
|
| 371 |
-
'focused_frames': focused_frames
|
| 372 |
-
}
|
| 373 |
-
|
| 374 |
-
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
|
| 375 |
-
async with aiosqlite.connect(db_path) as db:
|
| 376 |
-
await db.execute("""
|
| 377 |
-
INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
|
| 378 |
-
VALUES (?, ?, ?, ?, ?)
|
| 379 |
-
""", (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
|
| 380 |
-
|
| 381 |
-
await db.execute("""
|
| 382 |
-
UPDATE focus_sessions
|
| 383 |
-
SET total_frames = total_frames + 1,
|
| 384 |
-
focused_frames = focused_frames + ?
|
| 385 |
-
WHERE id = ?
|
| 386 |
-
""", (1 if is_focused else 0, session_id))
|
| 387 |
-
await db.commit()
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
class _EventBuffer:
|
| 391 |
-
"""Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes."""
|
| 392 |
-
|
| 393 |
-
def __init__(self, flush_interval: float = 2.0):
|
| 394 |
-
self._buf: list = []
|
| 395 |
-
self._lock = asyncio.Lock()
|
| 396 |
-
self._flush_interval = flush_interval
|
| 397 |
-
self._task: asyncio.Task | None = None
|
| 398 |
-
self._total_frames = 0
|
| 399 |
-
self._focused_frames = 0
|
| 400 |
-
|
| 401 |
-
def start(self):
|
| 402 |
-
if self._task is None:
|
| 403 |
-
self._task = asyncio.create_task(self._flush_loop())
|
| 404 |
-
|
| 405 |
-
async def stop(self):
|
| 406 |
-
if self._task:
|
| 407 |
-
self._task.cancel()
|
| 408 |
-
try:
|
| 409 |
-
await self._task
|
| 410 |
-
except asyncio.CancelledError:
|
| 411 |
-
pass
|
| 412 |
-
self._task = None
|
| 413 |
-
await self._flush()
|
| 414 |
-
|
| 415 |
-
def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict):
|
| 416 |
-
self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
|
| 417 |
-
self._total_frames += 1
|
| 418 |
-
if is_focused:
|
| 419 |
-
self._focused_frames += 1
|
| 420 |
-
|
| 421 |
-
async def _flush_loop(self):
|
| 422 |
-
while True:
|
| 423 |
-
await asyncio.sleep(self._flush_interval)
|
| 424 |
-
await self._flush()
|
| 425 |
-
|
| 426 |
-
async def _flush(self):
|
| 427 |
-
async with self._lock:
|
| 428 |
-
if not self._buf:
|
| 429 |
-
return
|
| 430 |
-
batch = self._buf[:]
|
| 431 |
-
total = self._total_frames
|
| 432 |
-
focused = self._focused_frames
|
| 433 |
-
self._buf.clear()
|
| 434 |
-
self._total_frames = 0
|
| 435 |
-
self._focused_frames = 0
|
| 436 |
-
|
| 437 |
-
if not batch:
|
| 438 |
-
return
|
| 439 |
-
|
| 440 |
-
session_id = batch[0][0]
|
| 441 |
-
try:
|
| 442 |
-
async with aiosqlite.connect(db_path) as db:
|
| 443 |
-
await db.executemany("""
|
| 444 |
-
INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
|
| 445 |
-
VALUES (?, ?, ?, ?, ?)
|
| 446 |
-
""", batch)
|
| 447 |
-
await db.execute("""
|
| 448 |
-
UPDATE focus_sessions
|
| 449 |
-
SET total_frames = total_frames + ?,
|
| 450 |
-
focused_frames = focused_frames + ?
|
| 451 |
-
WHERE id = ?
|
| 452 |
-
""", (total, focused, session_id))
|
| 453 |
-
await db.commit()
|
| 454 |
-
except Exception as e:
|
| 455 |
-
print(f"[DB] Flush error: {e}")
|
| 456 |
-
|
| 457 |
-
# ================ STARTUP/SHUTDOWN ================
|
| 458 |
-
|
| 459 |
-
pipelines = {
|
| 460 |
-
"geometric": None,
|
| 461 |
-
"mlp": None,
|
| 462 |
-
"hybrid": None,
|
| 463 |
-
"xgboost": None,
|
| 464 |
-
}
|
| 465 |
-
|
| 466 |
-
# Thread pool for CPU-bound inference so the event loop stays responsive.
|
| 467 |
-
_inference_executor = concurrent.futures.ThreadPoolExecutor(
|
| 468 |
-
max_workers=4,
|
| 469 |
-
thread_name_prefix="inference",
|
| 470 |
-
)
|
| 471 |
-
# One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
|
| 472 |
-
# multiple frames are processed in parallel by the thread pool.
|
| 473 |
-
_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
def _process_frame_safe(pipeline, frame, model_name: str):
|
| 477 |
-
"""Run process_frame in executor with per-pipeline lock."""
|
| 478 |
-
with _pipeline_locks[model_name]:
|
| 479 |
-
return pipeline.process_frame(frame)
|
| 480 |
-
|
| 481 |
-
@app.on_event("startup")
|
| 482 |
-
async def startup_event():
|
| 483 |
-
global pipelines, _cached_model_name
|
| 484 |
-
print(" Starting Focus Guard API...")
|
| 485 |
-
await init_database()
|
| 486 |
-
# Load cached model name from DB
|
| 487 |
-
async with aiosqlite.connect(db_path) as db:
|
| 488 |
-
cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
|
| 489 |
-
row = await cursor.fetchone()
|
| 490 |
-
if row:
|
| 491 |
-
_cached_model_name = row[0]
|
| 492 |
-
print("[OK] Database initialized")
|
| 493 |
-
|
| 494 |
-
try:
|
| 495 |
-
pipelines["geometric"] = FaceMeshPipeline()
|
| 496 |
-
print("[OK] FaceMeshPipeline (geometric) loaded")
|
| 497 |
-
except Exception as e:
|
| 498 |
-
print(f"[WARN] FaceMeshPipeline unavailable: {e}")
|
| 499 |
-
|
| 500 |
-
try:
|
| 501 |
-
pipelines["mlp"] = MLPPipeline()
|
| 502 |
-
print("[OK] MLPPipeline loaded")
|
| 503 |
-
except Exception as e:
|
| 504 |
-
print(f"[ERR] Failed to load MLPPipeline: {e}")
|
| 505 |
-
|
| 506 |
-
try:
|
| 507 |
-
pipelines["hybrid"] = HybridFocusPipeline()
|
| 508 |
-
print("[OK] HybridFocusPipeline loaded")
|
| 509 |
-
except Exception as e:
|
| 510 |
-
print(f"[WARN] HybridFocusPipeline unavailable: {e}")
|
| 511 |
-
|
| 512 |
-
try:
|
| 513 |
-
pipelines["xgboost"] = XGBoostPipeline()
|
| 514 |
-
print("[OK] XGBoostPipeline loaded")
|
| 515 |
-
except Exception as e:
|
| 516 |
-
print(f"[ERR] Failed to load XGBoostPipeline: {e}")
|
| 517 |
-
|
| 518 |
-
@app.on_event("shutdown")
|
| 519 |
-
async def shutdown_event():
|
| 520 |
-
_inference_executor.shutdown(wait=False)
|
| 521 |
-
print(" Shutting down Focus Guard API...")
|
| 522 |
-
|
| 523 |
-
# ================ WEBRTC SIGNALING ================
|
| 524 |
-
|
| 525 |
-
@app.post("/api/webrtc/offer")
|
| 526 |
-
async def webrtc_offer(offer: dict):
|
| 527 |
-
try:
|
| 528 |
-
print(f"Received WebRTC offer")
|
| 529 |
-
|
| 530 |
-
pc = RTCPeerConnection()
|
| 531 |
-
pcs.add(pc)
|
| 532 |
-
|
| 533 |
-
session_id = await create_session()
|
| 534 |
-
print(f"Created session: {session_id}")
|
| 535 |
-
|
| 536 |
-
channel_ref = {"channel": None}
|
| 537 |
-
|
| 538 |
-
@pc.on("datachannel")
|
| 539 |
-
def on_datachannel(channel):
|
| 540 |
-
print(f"Data channel opened")
|
| 541 |
-
channel_ref["channel"] = channel
|
| 542 |
-
|
| 543 |
-
@pc.on("track")
|
| 544 |
-
def on_track(track):
|
| 545 |
-
print(f"Received track: {track.kind}")
|
| 546 |
-
if track.kind == "video":
|
| 547 |
-
local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
|
| 548 |
-
pc.addTrack(local_track)
|
| 549 |
-
print(f"Video track added")
|
| 550 |
-
|
| 551 |
-
@track.on("ended")
|
| 552 |
-
async def on_ended():
|
| 553 |
-
print(f"Track ended")
|
| 554 |
-
|
| 555 |
-
@pc.on("connectionstatechange")
|
| 556 |
-
async def on_connectionstatechange():
|
| 557 |
-
print(f"Connection state changed: {pc.connectionState}")
|
| 558 |
-
if pc.connectionState in ("failed", "closed", "disconnected"):
|
| 559 |
-
try:
|
| 560 |
-
await end_session(session_id)
|
| 561 |
-
except Exception as e:
|
| 562 |
-
print(f"⚠Error ending session: {e}")
|
| 563 |
-
pcs.discard(pc)
|
| 564 |
-
await pc.close()
|
| 565 |
-
|
| 566 |
-
await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
|
| 567 |
-
print(f"Remote description set")
|
| 568 |
-
|
| 569 |
-
answer = await pc.createAnswer()
|
| 570 |
-
await pc.setLocalDescription(answer)
|
| 571 |
-
print(f"Answer created")
|
| 572 |
-
|
| 573 |
-
await _wait_for_ice_gathering(pc)
|
| 574 |
-
print(f"ICE gathering complete")
|
| 575 |
-
|
| 576 |
-
return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
|
| 577 |
-
|
| 578 |
-
except Exception as e:
|
| 579 |
-
print(f"WebRTC offer error: {e}")
|
| 580 |
-
import traceback
|
| 581 |
-
traceback.print_exc()
|
| 582 |
-
raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
|
| 583 |
-
|
| 584 |
-
# ================ WEBSOCKET ================
|
| 585 |
-
|
| 586 |
-
@app.websocket("/ws/video")
|
| 587 |
-
async def websocket_endpoint(websocket: WebSocket):
|
| 588 |
-
await websocket.accept()
|
| 589 |
-
session_id = None
|
| 590 |
-
frame_count = 0
|
| 591 |
-
running = True
|
| 592 |
-
event_buffer = _EventBuffer(flush_interval=2.0)
|
| 593 |
-
|
| 594 |
-
# Latest frame slot — only the most recent frame is kept, older ones are dropped.
|
| 595 |
-
# Using a dict so nested functions can mutate without nonlocal issues.
|
| 596 |
-
_slot = {"frame": None}
|
| 597 |
-
_frame_ready = asyncio.Event()
|
| 598 |
-
|
| 599 |
-
async def _receive_loop():
|
| 600 |
-
"""Receive messages as fast as possible. Binary = frame, text = control."""
|
| 601 |
-
nonlocal session_id, running
|
| 602 |
-
try:
|
| 603 |
-
while running:
|
| 604 |
-
msg = await websocket.receive()
|
| 605 |
-
msg_type = msg.get("type", "")
|
| 606 |
-
|
| 607 |
-
if msg_type == "websocket.disconnect":
|
| 608 |
-
running = False
|
| 609 |
-
_frame_ready.set()
|
| 610 |
-
return
|
| 611 |
-
|
| 612 |
-
# Binary message → JPEG frame (fast path, no base64)
|
| 613 |
-
raw_bytes = msg.get("bytes")
|
| 614 |
-
if raw_bytes is not None and len(raw_bytes) > 0:
|
| 615 |
-
_slot["frame"] = raw_bytes
|
| 616 |
-
_frame_ready.set()
|
| 617 |
-
continue
|
| 618 |
-
|
| 619 |
-
# Text message → JSON control command (or legacy base64 frame)
|
| 620 |
-
text = msg.get("text")
|
| 621 |
-
if not text:
|
| 622 |
-
continue
|
| 623 |
-
data = json.loads(text)
|
| 624 |
-
|
| 625 |
-
if data["type"] == "frame":
|
| 626 |
-
# Legacy base64 path (fallback)
|
| 627 |
-
_slot["frame"] = base64.b64decode(data["image"])
|
| 628 |
-
_frame_ready.set()
|
| 629 |
-
|
| 630 |
-
elif data["type"] == "start_session":
|
| 631 |
-
session_id = await create_session()
|
| 632 |
-
event_buffer.start()
|
| 633 |
-
for p in pipelines.values():
|
| 634 |
-
if p is not None and hasattr(p, "reset_session"):
|
| 635 |
-
p.reset_session()
|
| 636 |
-
await websocket.send_json({"type": "session_started", "session_id": session_id})
|
| 637 |
-
|
| 638 |
-
elif data["type"] == "end_session":
|
| 639 |
-
if session_id:
|
| 640 |
-
await event_buffer.stop()
|
| 641 |
-
summary = await end_session(session_id)
|
| 642 |
-
if summary:
|
| 643 |
-
await websocket.send_json({"type": "session_ended", "summary": summary})
|
| 644 |
-
session_id = None
|
| 645 |
-
except WebSocketDisconnect:
|
| 646 |
-
running = False
|
| 647 |
-
_frame_ready.set()
|
| 648 |
-
except Exception as e:
|
| 649 |
-
print(f"[WS] receive error: {e}")
|
| 650 |
-
running = False
|
| 651 |
-
_frame_ready.set()
|
| 652 |
-
|
| 653 |
-
async def _process_loop():
|
| 654 |
-
"""Process only the latest frame, dropping stale ones."""
|
| 655 |
-
nonlocal frame_count, running
|
| 656 |
-
loop = asyncio.get_event_loop()
|
| 657 |
-
while running:
|
| 658 |
-
await _frame_ready.wait()
|
| 659 |
-
_frame_ready.clear()
|
| 660 |
-
if not running:
|
| 661 |
-
return
|
| 662 |
-
|
| 663 |
-
# Grab latest frame and clear slot
|
| 664 |
-
raw = _slot["frame"]
|
| 665 |
-
_slot["frame"] = None
|
| 666 |
-
if raw is None:
|
| 667 |
-
continue
|
| 668 |
-
|
| 669 |
-
try:
|
| 670 |
-
nparr = np.frombuffer(raw, np.uint8)
|
| 671 |
-
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
| 672 |
-
if frame is None:
|
| 673 |
-
continue
|
| 674 |
-
frame = cv2.resize(frame, (640, 480))
|
| 675 |
-
|
| 676 |
-
model_name = _cached_model_name
|
| 677 |
-
if model_name not in pipelines or pipelines.get(model_name) is None:
|
| 678 |
-
model_name = "mlp"
|
| 679 |
-
active_pipeline = pipelines.get(model_name)
|
| 680 |
-
|
| 681 |
-
landmarks_list = None
|
| 682 |
-
if active_pipeline is not None:
|
| 683 |
-
out = await loop.run_in_executor(
|
| 684 |
-
_inference_executor,
|
| 685 |
-
_process_frame_safe,
|
| 686 |
-
active_pipeline,
|
| 687 |
-
frame,
|
| 688 |
-
model_name,
|
| 689 |
-
)
|
| 690 |
-
is_focused = out["is_focused"]
|
| 691 |
-
confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
|
| 692 |
-
|
| 693 |
-
lm = out.get("landmarks")
|
| 694 |
-
if lm is not None:
|
| 695 |
-
# Send all 478 landmarks as flat array for tessellation drawing
|
| 696 |
-
landmarks_list = [
|
| 697 |
-
[round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
|
| 698 |
-
for i in range(lm.shape[0])
|
| 699 |
-
]
|
| 700 |
-
|
| 701 |
-
if session_id:
|
| 702 |
-
event_buffer.add(session_id, is_focused, confidence, {
|
| 703 |
-
"s_face": out.get("s_face", 0.0),
|
| 704 |
-
"s_eye": out.get("s_eye", 0.0),
|
| 705 |
-
"mar": out.get("mar", 0.0),
|
| 706 |
-
"model": model_name,
|
| 707 |
-
})
|
| 708 |
-
else:
|
| 709 |
-
is_focused = False
|
| 710 |
-
confidence = 0.0
|
| 711 |
-
|
| 712 |
-
resp = {
|
| 713 |
-
"type": "detection",
|
| 714 |
-
"focused": is_focused,
|
| 715 |
-
"confidence": round(confidence, 3),
|
| 716 |
-
"detections": [],
|
| 717 |
-
"model": model_name,
|
| 718 |
-
"fc": frame_count,
|
| 719 |
-
"frame_count": frame_count,
|
| 720 |
-
}
|
| 721 |
-
if active_pipeline is not None:
|
| 722 |
-
# Send detailed metrics for HUD
|
| 723 |
-
if out.get("yaw") is not None:
|
| 724 |
-
resp["yaw"] = round(out["yaw"], 1)
|
| 725 |
-
resp["pitch"] = round(out["pitch"], 1)
|
| 726 |
-
resp["roll"] = round(out["roll"], 1)
|
| 727 |
-
if out.get("mar") is not None:
|
| 728 |
-
resp["mar"] = round(out["mar"], 3)
|
| 729 |
-
resp["sf"] = round(out.get("s_face", 0), 3)
|
| 730 |
-
resp["se"] = round(out.get("s_eye", 0), 3)
|
| 731 |
-
if landmarks_list is not None:
|
| 732 |
-
resp["lm"] = landmarks_list
|
| 733 |
-
await websocket.send_json(resp)
|
| 734 |
-
frame_count += 1
|
| 735 |
-
except Exception as e:
|
| 736 |
-
print(f"[WS] process error: {e}")
|
| 737 |
-
|
| 738 |
-
try:
|
| 739 |
-
await asyncio.gather(_receive_loop(), _process_loop())
|
| 740 |
-
except Exception:
|
| 741 |
-
pass
|
| 742 |
-
finally:
|
| 743 |
-
running = False
|
| 744 |
-
if session_id:
|
| 745 |
-
await event_buffer.stop()
|
| 746 |
-
await end_session(session_id)
|
| 747 |
-
|
| 748 |
-
# ================ API ENDPOINTS ================
|
| 749 |
-
|
| 750 |
-
@app.post("/api/sessions/start")
|
| 751 |
-
async def api_start_session():
|
| 752 |
-
session_id = await create_session()
|
| 753 |
-
return {"session_id": session_id}
|
| 754 |
-
|
| 755 |
-
@app.post("/api/sessions/end")
|
| 756 |
-
async def api_end_session(data: SessionEnd):
|
| 757 |
-
summary = await end_session(data.session_id)
|
| 758 |
-
if not summary: raise HTTPException(status_code=404, detail="Session not found")
|
| 759 |
-
return summary
|
| 760 |
-
|
| 761 |
-
@app.get("/api/sessions")
|
| 762 |
-
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
|
| 763 |
-
async with aiosqlite.connect(db_path) as db:
|
| 764 |
-
db.row_factory = aiosqlite.Row
|
| 765 |
-
|
| 766 |
-
# NEW: If importing/exporting all, remove limit if special flag or high limit
|
| 767 |
-
# For simplicity: if limit is -1, return all
|
| 768 |
-
limit_clause = "LIMIT ? OFFSET ?"
|
| 769 |
-
params = []
|
| 770 |
-
|
| 771 |
-
base_query = "SELECT * FROM focus_sessions"
|
| 772 |
-
where_clause = ""
|
| 773 |
-
|
| 774 |
-
if filter == "today":
|
| 775 |
-
date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
| 776 |
-
where_clause = " WHERE start_time >= ?"
|
| 777 |
-
params.append(date_filter.isoformat())
|
| 778 |
-
elif filter == "week":
|
| 779 |
-
date_filter = datetime.now() - timedelta(days=7)
|
| 780 |
-
where_clause = " WHERE start_time >= ?"
|
| 781 |
-
params.append(date_filter.isoformat())
|
| 782 |
-
elif filter == "month":
|
| 783 |
-
date_filter = datetime.now() - timedelta(days=30)
|
| 784 |
-
where_clause = " WHERE start_time >= ?"
|
| 785 |
-
params.append(date_filter.isoformat())
|
| 786 |
-
elif filter == "all":
|
| 787 |
-
# Just ensure we only get completed sessions or all sessions
|
| 788 |
-
where_clause = " WHERE end_time IS NOT NULL"
|
| 789 |
-
|
| 790 |
-
query = f"{base_query}{where_clause} ORDER BY start_time DESC"
|
| 791 |
-
|
| 792 |
-
# Handle Limit for Exports
|
| 793 |
-
if limit == -1:
|
| 794 |
-
# No limit clause for export
|
| 795 |
-
pass
|
| 796 |
-
else:
|
| 797 |
-
query += f" {limit_clause}"
|
| 798 |
-
params.extend([limit, offset])
|
| 799 |
-
|
| 800 |
-
cursor = await db.execute(query, tuple(params))
|
| 801 |
-
rows = await cursor.fetchall()
|
| 802 |
-
return [dict(row) for row in rows]
|
| 803 |
-
|
| 804 |
-
# --- NEW: Import Endpoint ---
|
| 805 |
-
@app.post("/api/import")
|
| 806 |
-
async def import_sessions(sessions: List[dict]):
|
| 807 |
-
count = 0
|
| 808 |
-
try:
|
| 809 |
-
async with aiosqlite.connect(db_path) as db:
|
| 810 |
-
for session in sessions:
|
| 811 |
-
# Use .get() to handle potential missing fields from older versions or edits
|
| 812 |
-
await db.execute("""
|
| 813 |
-
INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
|
| 814 |
-
VALUES (?, ?, ?, ?, ?, ?, ?)
|
| 815 |
-
""", (
|
| 816 |
-
session.get('start_time'),
|
| 817 |
-
session.get('end_time'),
|
| 818 |
-
session.get('duration_seconds', 0),
|
| 819 |
-
session.get('focus_score', 0.0),
|
| 820 |
-
session.get('total_frames', 0),
|
| 821 |
-
session.get('focused_frames', 0),
|
| 822 |
-
session.get('created_at', session.get('start_time'))
|
| 823 |
-
))
|
| 824 |
-
count += 1
|
| 825 |
-
await db.commit()
|
| 826 |
-
return {"status": "success", "count": count}
|
| 827 |
-
except Exception as e:
|
| 828 |
-
print(f"Import Error: {e}")
|
| 829 |
-
return {"status": "error", "message": str(e)}
|
| 830 |
-
|
| 831 |
-
# --- NEW: Clear History Endpoint ---
|
| 832 |
-
@app.delete("/api/history")
|
| 833 |
-
async def clear_history():
|
| 834 |
-
try:
|
| 835 |
-
async with aiosqlite.connect(db_path) as db:
|
| 836 |
-
# Delete events first (foreign key good practice)
|
| 837 |
-
await db.execute("DELETE FROM focus_events")
|
| 838 |
-
await db.execute("DELETE FROM focus_sessions")
|
| 839 |
-
await db.commit()
|
| 840 |
-
return {"status": "success", "message": "History cleared"}
|
| 841 |
-
except Exception as e:
|
| 842 |
-
return {"status": "error", "message": str(e)}
|
| 843 |
-
|
| 844 |
-
@app.get("/api/sessions/{session_id}")
|
| 845 |
-
async def get_session(session_id: int):
|
| 846 |
-
async with aiosqlite.connect(db_path) as db:
|
| 847 |
-
db.row_factory = aiosqlite.Row
|
| 848 |
-
cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
|
| 849 |
-
row = await cursor.fetchone()
|
| 850 |
-
if not row: raise HTTPException(status_code=404, detail="Session not found")
|
| 851 |
-
session = dict(row)
|
| 852 |
-
cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
|
| 853 |
-
events = [dict(r) for r in await cursor.fetchall()]
|
| 854 |
-
session['events'] = events
|
| 855 |
-
return session
|
| 856 |
-
|
| 857 |
-
@app.get("/api/settings")
|
| 858 |
-
async def get_settings():
|
| 859 |
-
async with aiosqlite.connect(db_path) as db:
|
| 860 |
-
db.row_factory = aiosqlite.Row
|
| 861 |
-
cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
|
| 862 |
-
row = await cursor.fetchone()
|
| 863 |
-
if row: return dict(row)
|
| 864 |
-
else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
|
| 865 |
-
|
| 866 |
-
@app.put("/api/settings")
|
| 867 |
-
async def update_settings(settings: SettingsUpdate):
|
| 868 |
-
async with aiosqlite.connect(db_path) as db:
|
| 869 |
-
cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
|
| 870 |
-
exists = await cursor.fetchone()
|
| 871 |
-
if not exists:
|
| 872 |
-
await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
|
| 873 |
-
await db.commit()
|
| 874 |
-
|
| 875 |
-
updates = []
|
| 876 |
-
params = []
|
| 877 |
-
if settings.sensitivity is not None:
|
| 878 |
-
updates.append("sensitivity = ?")
|
| 879 |
-
params.append(max(1, min(10, settings.sensitivity)))
|
| 880 |
-
if settings.notification_enabled is not None:
|
| 881 |
-
updates.append("notification_enabled = ?")
|
| 882 |
-
params.append(settings.notification_enabled)
|
| 883 |
-
if settings.notification_threshold is not None:
|
| 884 |
-
updates.append("notification_threshold = ?")
|
| 885 |
-
params.append(max(5, min(300, settings.notification_threshold)))
|
| 886 |
-
if settings.frame_rate is not None:
|
| 887 |
-
updates.append("frame_rate = ?")
|
| 888 |
-
params.append(max(5, min(60, settings.frame_rate)))
|
| 889 |
-
if settings.model_name is not None and settings.model_name in pipelines and pipelines[settings.model_name] is not None:
|
| 890 |
-
updates.append("model_name = ?")
|
| 891 |
-
params.append(settings.model_name)
|
| 892 |
-
global _cached_model_name
|
| 893 |
-
_cached_model_name = settings.model_name
|
| 894 |
-
|
| 895 |
-
if updates:
|
| 896 |
-
query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
|
| 897 |
-
await db.execute(query, params)
|
| 898 |
-
await db.commit()
|
| 899 |
-
return {"status": "success", "updated": len(updates) > 0}
|
| 900 |
-
|
| 901 |
-
@app.get("/api/stats/summary")
|
| 902 |
-
async def get_stats_summary():
|
| 903 |
-
async with aiosqlite.connect(db_path) as db:
|
| 904 |
-
cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 905 |
-
total_sessions = (await cursor.fetchone())[0]
|
| 906 |
-
cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 907 |
-
total_focus_time = (await cursor.fetchone())[0] or 0
|
| 908 |
-
cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 909 |
-
avg_focus_score = (await cursor.fetchone())[0] or 0.0
|
| 910 |
-
cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
|
| 911 |
-
dates = [row[0] for row in await cursor.fetchall()]
|
| 912 |
-
|
| 913 |
-
streak_days = 0
|
| 914 |
-
if dates:
|
| 915 |
-
current_date = datetime.now().date()
|
| 916 |
-
for i, date_str in enumerate(dates):
|
| 917 |
-
session_date = datetime.fromisoformat(date_str).date()
|
| 918 |
-
expected_date = current_date - timedelta(days=i)
|
| 919 |
-
if session_date == expected_date: streak_days += 1
|
| 920 |
-
else: break
|
| 921 |
-
return {
|
| 922 |
-
'total_sessions': total_sessions,
|
| 923 |
-
'total_focus_time': int(total_focus_time),
|
| 924 |
-
'avg_focus_score': round(avg_focus_score, 3),
|
| 925 |
-
'streak_days': streak_days
|
| 926 |
-
}
|
| 927 |
-
|
| 928 |
-
@app.get("/api/models")
|
| 929 |
-
async def get_available_models():
|
| 930 |
-
"""Return list of loaded model names and which is currently active."""
|
| 931 |
-
available = [name for name, p in pipelines.items() if p is not None]
|
| 932 |
-
async with aiosqlite.connect(db_path) as db:
|
| 933 |
-
cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
|
| 934 |
-
row = await cursor.fetchone()
|
| 935 |
-
current = row[0] if row else "mlp"
|
| 936 |
-
if current not in available and available:
|
| 937 |
-
current = available[0]
|
| 938 |
-
return {"available": available, "current": current}
|
| 939 |
-
|
| 940 |
-
@app.get("/api/mesh-topology")
|
| 941 |
-
async def get_mesh_topology():
|
| 942 |
-
"""Return tessellation edge pairs for client-side face mesh drawing (cached by client)."""
|
| 943 |
-
return {"tessellation": _TESSELATION_CONNS}
|
| 944 |
-
|
| 945 |
-
@app.get("/health")
|
| 946 |
-
async def health_check():
|
| 947 |
-
available = [name for name, p in pipelines.items() if p is not None]
|
| 948 |
-
return {"status": "healthy", "models_loaded": available, "database": os.path.exists(db_path)}
|
| 949 |
-
|
| 950 |
-
# ================ STATIC FILES (SPA SUPPORT) ================
|
| 951 |
-
|
| 952 |
-
# Resolve frontend dir from this file so it works regardless of cwd.
|
| 953 |
-
# Prefer a built `dist/` app when present, otherwise fall back to `static/`.
|
| 954 |
-
_BASE_DIR = Path(__file__).resolve().parent
|
| 955 |
-
_DIST_DIR = _BASE_DIR / "dist"
|
| 956 |
-
_STATIC_DIR = _BASE_DIR / "static"
|
| 957 |
-
_FRONTEND_DIR = _DIST_DIR if (_DIST_DIR / "index.html").is_file() else _STATIC_DIR
|
| 958 |
-
_ASSETS_DIR = _FRONTEND_DIR / "assets"
|
| 959 |
-
|
| 960 |
-
# 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
|
| 961 |
-
if _ASSETS_DIR.is_dir():
|
| 962 |
-
app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")
|
| 963 |
-
|
| 964 |
-
# 2. Catch-all for SPA: serve index.html for app routes, never for /assets (would break JS MIME type)
|
| 965 |
-
@app.get("/{full_path:path}")
|
| 966 |
-
async def serve_react_app(full_path: str, request: Request):
|
| 967 |
-
if full_path.startswith("api") or full_path.startswith("ws"):
|
| 968 |
-
raise HTTPException(status_code=404, detail="Not Found")
|
| 969 |
-
# Don't serve HTML for asset paths; let them 404 so we don't break module script loading
|
| 970 |
-
if full_path.startswith("assets") or full_path.startswith("assets/"):
|
| 971 |
-
raise HTTPException(status_code=404, detail="Not Found")
|
| 972 |
-
|
| 973 |
-
file_path = _FRONTEND_DIR / full_path
|
| 974 |
-
if full_path and file_path.is_file():
|
| 975 |
-
return FileResponse(str(file_path))
|
| 976 |
-
|
| 977 |
-
index_path = _FRONTEND_DIR / "index.html"
|
| 978 |
-
if index_path.is_file():
|
| 979 |
-
return FileResponse(str(index_path))
|
| 980 |
-
return {"message": "React app not found. Please run 'npm run build' and copy dist to static if needed."}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main_yolo.py
DELETED
|
@@ -1,601 +0,0 @@
|
|
| 1 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
|
| 2 |
-
from fastapi.staticfiles import StaticFiles
|
| 3 |
-
from fastapi.responses import FileResponse
|
| 4 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
-
from pydantic import BaseModel
|
| 6 |
-
from typing import Optional, List, Any
|
| 7 |
-
import base64
|
| 8 |
-
import cv2
|
| 9 |
-
import numpy as np
|
| 10 |
-
import aiosqlite
|
| 11 |
-
import json
|
| 12 |
-
from datetime import datetime, timedelta
|
| 13 |
-
import math
|
| 14 |
-
import os
|
| 15 |
-
from pathlib import Path
|
| 16 |
-
from typing import Callable
|
| 17 |
-
import asyncio
|
| 18 |
-
|
| 19 |
-
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
|
| 20 |
-
from av import VideoFrame
|
| 21 |
-
|
| 22 |
-
from ui.pipeline import MLPPipeline
|
| 23 |
-
|
| 24 |
-
# Initialize FastAPI app
|
| 25 |
-
app = FastAPI(title="Focus Guard API")
|
| 26 |
-
|
| 27 |
-
# Add CORS middleware
|
| 28 |
-
app.add_middleware(
|
| 29 |
-
CORSMiddleware,
|
| 30 |
-
allow_origins=["*"],
|
| 31 |
-
allow_credentials=True,
|
| 32 |
-
allow_methods=["*"],
|
| 33 |
-
allow_headers=["*"],
|
| 34 |
-
)
|
| 35 |
-
|
| 36 |
-
# Global variables
|
| 37 |
-
db_path = "focus_guard.db"
|
| 38 |
-
pcs = set()
|
| 39 |
-
|
| 40 |
-
async def _wait_for_ice_gathering(pc: RTCPeerConnection):
|
| 41 |
-
if pc.iceGatheringState == "complete":
|
| 42 |
-
return
|
| 43 |
-
done = asyncio.Event()
|
| 44 |
-
|
| 45 |
-
@pc.on("icegatheringstatechange")
|
| 46 |
-
def _on_state_change():
|
| 47 |
-
if pc.iceGatheringState == "complete":
|
| 48 |
-
done.set()
|
| 49 |
-
|
| 50 |
-
await done.wait()
|
| 51 |
-
|
| 52 |
-
# ================ DATABASE MODELS ================
|
| 53 |
-
|
| 54 |
-
async def init_database():
|
| 55 |
-
"""Initialize SQLite database with required tables"""
|
| 56 |
-
async with aiosqlite.connect(db_path) as db:
|
| 57 |
-
# FocusSessions table
|
| 58 |
-
await db.execute("""
|
| 59 |
-
CREATE TABLE IF NOT EXISTS focus_sessions (
|
| 60 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 61 |
-
start_time TIMESTAMP NOT NULL,
|
| 62 |
-
end_time TIMESTAMP,
|
| 63 |
-
duration_seconds INTEGER DEFAULT 0,
|
| 64 |
-
focus_score REAL DEFAULT 0.0,
|
| 65 |
-
total_frames INTEGER DEFAULT 0,
|
| 66 |
-
focused_frames INTEGER DEFAULT 0,
|
| 67 |
-
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 68 |
-
)
|
| 69 |
-
""")
|
| 70 |
-
|
| 71 |
-
# FocusEvents table
|
| 72 |
-
await db.execute("""
|
| 73 |
-
CREATE TABLE IF NOT EXISTS focus_events (
|
| 74 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 75 |
-
session_id INTEGER NOT NULL,
|
| 76 |
-
timestamp TIMESTAMP NOT NULL,
|
| 77 |
-
is_focused BOOLEAN NOT NULL,
|
| 78 |
-
confidence REAL NOT NULL,
|
| 79 |
-
detection_data TEXT,
|
| 80 |
-
FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
|
| 81 |
-
)
|
| 82 |
-
""")
|
| 83 |
-
|
| 84 |
-
# UserSettings table
|
| 85 |
-
await db.execute("""
|
| 86 |
-
CREATE TABLE IF NOT EXISTS user_settings (
|
| 87 |
-
id INTEGER PRIMARY KEY CHECK (id = 1),
|
| 88 |
-
sensitivity INTEGER DEFAULT 6,
|
| 89 |
-
notification_enabled BOOLEAN DEFAULT 1,
|
| 90 |
-
notification_threshold INTEGER DEFAULT 30,
|
| 91 |
-
frame_rate INTEGER DEFAULT 30,
|
| 92 |
-
model_name TEXT DEFAULT 'yolov8n.pt'
|
| 93 |
-
)
|
| 94 |
-
""")
|
| 95 |
-
|
| 96 |
-
# Insert default settings if not exists
|
| 97 |
-
await db.execute("""
|
| 98 |
-
INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
|
| 99 |
-
VALUES (1, 6, 1, 30, 30, 'yolov8n.pt')
|
| 100 |
-
""")
|
| 101 |
-
|
| 102 |
-
await db.commit()
|
| 103 |
-
|
| 104 |
-
# ================ PYDANTIC MODELS ================
|
| 105 |
-
|
| 106 |
-
class SessionCreate(BaseModel):
|
| 107 |
-
pass
|
| 108 |
-
|
| 109 |
-
class SessionEnd(BaseModel):
|
| 110 |
-
session_id: int
|
| 111 |
-
|
| 112 |
-
class SettingsUpdate(BaseModel):
|
| 113 |
-
sensitivity: Optional[int] = None
|
| 114 |
-
notification_enabled: Optional[bool] = None
|
| 115 |
-
notification_threshold: Optional[int] = None
|
| 116 |
-
frame_rate: Optional[int] = None
|
| 117 |
-
|
| 118 |
-
class VideoTransformTrack(VideoStreamTrack):
|
| 119 |
-
def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
|
| 120 |
-
super().__init__()
|
| 121 |
-
self.track = track
|
| 122 |
-
self.session_id = session_id
|
| 123 |
-
self.get_channel = get_channel
|
| 124 |
-
self.last_inference_time = 0
|
| 125 |
-
self.min_inference_interval = 1 / 60
|
| 126 |
-
self.last_frame = None
|
| 127 |
-
|
| 128 |
-
async def recv(self):
|
| 129 |
-
frame = await self.track.recv()
|
| 130 |
-
img = frame.to_ndarray(format="bgr24")
|
| 131 |
-
if img is None:
|
| 132 |
-
return frame
|
| 133 |
-
|
| 134 |
-
# Normalize size for inference/drawing
|
| 135 |
-
img = cv2.resize(img, (640, 480))
|
| 136 |
-
|
| 137 |
-
now = datetime.now().timestamp()
|
| 138 |
-
do_infer = (now - self.last_inference_time) >= self.min_inference_interval
|
| 139 |
-
|
| 140 |
-
if do_infer and mlp_pipeline is not None:
|
| 141 |
-
self.last_inference_time = now
|
| 142 |
-
out = mlp_pipeline.process_frame(img)
|
| 143 |
-
is_focused = out["is_focused"]
|
| 144 |
-
confidence = out["mlp_prob"]
|
| 145 |
-
metadata = {"s_face": out["s_face"], "s_eye": out["s_eye"], "mar": out["mar"]}
|
| 146 |
-
detections = []
|
| 147 |
-
status_text = "FOCUSED" if is_focused else "NOT FOCUSED"
|
| 148 |
-
color = (0, 255, 0) if is_focused else (0, 0, 255)
|
| 149 |
-
cv2.putText(img, status_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
|
| 150 |
-
cv2.putText(img, f"Confidence: {confidence * 100:.1f}%", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
| 151 |
-
|
| 152 |
-
if self.session_id:
|
| 153 |
-
await store_focus_event(self.session_id, is_focused, confidence, metadata)
|
| 154 |
-
|
| 155 |
-
channel = self.get_channel()
|
| 156 |
-
if channel and channel.readyState == "open":
|
| 157 |
-
try:
|
| 158 |
-
channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": detections}))
|
| 159 |
-
except Exception:
|
| 160 |
-
pass
|
| 161 |
-
|
| 162 |
-
self.last_frame = img
|
| 163 |
-
elif self.last_frame is not None:
|
| 164 |
-
img = self.last_frame
|
| 165 |
-
|
| 166 |
-
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
|
| 167 |
-
new_frame.pts = frame.pts
|
| 168 |
-
new_frame.time_base = frame.time_base
|
| 169 |
-
return new_frame
|
| 170 |
-
|
| 171 |
-
# ================ DATABASE OPERATIONS ================
|
| 172 |
-
|
| 173 |
-
async def create_session():
|
| 174 |
-
async with aiosqlite.connect(db_path) as db:
|
| 175 |
-
cursor = await db.execute(
|
| 176 |
-
"INSERT INTO focus_sessions (start_time) VALUES (?)",
|
| 177 |
-
(datetime.now().isoformat(),)
|
| 178 |
-
)
|
| 179 |
-
await db.commit()
|
| 180 |
-
return cursor.lastrowid
|
| 181 |
-
|
| 182 |
-
async def end_session(session_id: int):
|
| 183 |
-
async with aiosqlite.connect(db_path) as db:
|
| 184 |
-
cursor = await db.execute(
|
| 185 |
-
"SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
|
| 186 |
-
(session_id,)
|
| 187 |
-
)
|
| 188 |
-
row = await cursor.fetchone()
|
| 189 |
-
|
| 190 |
-
if not row:
|
| 191 |
-
return None
|
| 192 |
-
|
| 193 |
-
start_time_str, total_frames, focused_frames = row
|
| 194 |
-
start_time = datetime.fromisoformat(start_time_str)
|
| 195 |
-
end_time = datetime.now()
|
| 196 |
-
duration = (end_time - start_time).total_seconds()
|
| 197 |
-
focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
|
| 198 |
-
|
| 199 |
-
await db.execute("""
|
| 200 |
-
UPDATE focus_sessions
|
| 201 |
-
SET end_time = ?, duration_seconds = ?, focus_score = ?
|
| 202 |
-
WHERE id = ?
|
| 203 |
-
""", (end_time.isoformat(), int(duration), focus_score, session_id))
|
| 204 |
-
|
| 205 |
-
await db.commit()
|
| 206 |
-
|
| 207 |
-
return {
|
| 208 |
-
'session_id': session_id,
|
| 209 |
-
'start_time': start_time_str,
|
| 210 |
-
'end_time': end_time.isoformat(),
|
| 211 |
-
'duration_seconds': int(duration),
|
| 212 |
-
'focus_score': round(focus_score, 3),
|
| 213 |
-
'total_frames': total_frames,
|
| 214 |
-
'focused_frames': focused_frames
|
| 215 |
-
}
|
| 216 |
-
|
| 217 |
-
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
|
| 218 |
-
async with aiosqlite.connect(db_path) as db:
|
| 219 |
-
await db.execute("""
|
| 220 |
-
INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
|
| 221 |
-
VALUES (?, ?, ?, ?, ?)
|
| 222 |
-
""", (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
|
| 223 |
-
|
| 224 |
-
await db.execute(f"""
|
| 225 |
-
UPDATE focus_sessions
|
| 226 |
-
SET total_frames = total_frames + 1,
|
| 227 |
-
focused_frames = focused_frames + {1 if is_focused else 0}
|
| 228 |
-
WHERE id = ?
|
| 229 |
-
""", (session_id,))
|
| 230 |
-
await db.commit()
|
| 231 |
-
|
| 232 |
-
# ================ STARTUP/SHUTDOWN ================
|
| 233 |
-
|
| 234 |
-
mlp_pipeline = None
|
| 235 |
-
|
| 236 |
-
@app.on_event("startup")
|
| 237 |
-
async def startup_event():
|
| 238 |
-
global mlp_pipeline
|
| 239 |
-
print(" Starting Focus Guard API...")
|
| 240 |
-
await init_database()
|
| 241 |
-
print("[OK] Database initialized")
|
| 242 |
-
|
| 243 |
-
mlp_pipeline = MLPPipeline()
|
| 244 |
-
print("[OK] MLPPipeline loaded")
|
| 245 |
-
|
| 246 |
-
@app.on_event("shutdown")
|
| 247 |
-
async def shutdown_event():
|
| 248 |
-
print(" Shutting down Focus Guard API...")
|
| 249 |
-
|
| 250 |
-
# ================ WEBRTC SIGNALING ================
|
| 251 |
-
|
| 252 |
-
@app.post("/api/webrtc/offer")
|
| 253 |
-
async def webrtc_offer(offer: dict):
|
| 254 |
-
try:
|
| 255 |
-
print(f"Received WebRTC offer")
|
| 256 |
-
|
| 257 |
-
pc = RTCPeerConnection()
|
| 258 |
-
pcs.add(pc)
|
| 259 |
-
|
| 260 |
-
session_id = await create_session()
|
| 261 |
-
print(f"Created session: {session_id}")
|
| 262 |
-
|
| 263 |
-
channel_ref = {"channel": None}
|
| 264 |
-
|
| 265 |
-
@pc.on("datachannel")
|
| 266 |
-
def on_datachannel(channel):
|
| 267 |
-
print(f"Data channel opened")
|
| 268 |
-
channel_ref["channel"] = channel
|
| 269 |
-
|
| 270 |
-
@pc.on("track")
|
| 271 |
-
def on_track(track):
|
| 272 |
-
print(f"Received track: {track.kind}")
|
| 273 |
-
if track.kind == "video":
|
| 274 |
-
local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
|
| 275 |
-
pc.addTrack(local_track)
|
| 276 |
-
print(f"Video track added")
|
| 277 |
-
|
| 278 |
-
@track.on("ended")
|
| 279 |
-
async def on_ended():
|
| 280 |
-
print(f"Track ended")
|
| 281 |
-
|
| 282 |
-
@pc.on("connectionstatechange")
|
| 283 |
-
async def on_connectionstatechange():
|
| 284 |
-
print(f"Connection state changed: {pc.connectionState}")
|
| 285 |
-
if pc.connectionState in ("failed", "closed", "disconnected"):
|
| 286 |
-
try:
|
| 287 |
-
await end_session(session_id)
|
| 288 |
-
except Exception as e:
|
| 289 |
-
print(f"⚠Error ending session: {e}")
|
| 290 |
-
pcs.discard(pc)
|
| 291 |
-
await pc.close()
|
| 292 |
-
|
| 293 |
-
await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
|
| 294 |
-
print(f"Remote description set")
|
| 295 |
-
|
| 296 |
-
answer = await pc.createAnswer()
|
| 297 |
-
await pc.setLocalDescription(answer)
|
| 298 |
-
print(f"Answer created")
|
| 299 |
-
|
| 300 |
-
await _wait_for_ice_gathering(pc)
|
| 301 |
-
print(f"ICE gathering complete")
|
| 302 |
-
|
| 303 |
-
return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
|
| 304 |
-
|
| 305 |
-
except Exception as e:
|
| 306 |
-
print(f"WebRTC offer error: {e}")
|
| 307 |
-
import traceback
|
| 308 |
-
traceback.print_exc()
|
| 309 |
-
raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
|
| 310 |
-
|
| 311 |
-
# ================ WEBSOCKET ================
|
| 312 |
-
|
| 313 |
-
@app.websocket("/ws/video")
|
| 314 |
-
async def websocket_endpoint(websocket: WebSocket):
|
| 315 |
-
await websocket.accept()
|
| 316 |
-
session_id = None
|
| 317 |
-
frame_count = 0
|
| 318 |
-
last_inference_time = 0
|
| 319 |
-
min_inference_interval = 1 / 60
|
| 320 |
-
|
| 321 |
-
try:
|
| 322 |
-
async with aiosqlite.connect(db_path) as db:
|
| 323 |
-
cursor = await db.execute("SELECT sensitivity FROM user_settings WHERE id = 1")
|
| 324 |
-
row = await cursor.fetchone()
|
| 325 |
-
sensitivity = row[0] if row else 6
|
| 326 |
-
|
| 327 |
-
while True:
|
| 328 |
-
data = await websocket.receive_json()
|
| 329 |
-
|
| 330 |
-
if data['type'] == 'frame':
|
| 331 |
-
from time import time
|
| 332 |
-
current_time = time()
|
| 333 |
-
if current_time - last_inference_time < min_inference_interval:
|
| 334 |
-
await websocket.send_json({'type': 'ack', 'frame_count': frame_count})
|
| 335 |
-
continue
|
| 336 |
-
last_inference_time = current_time
|
| 337 |
-
|
| 338 |
-
try:
|
| 339 |
-
img_data = base64.b64decode(data['image'])
|
| 340 |
-
nparr = np.frombuffer(img_data, np.uint8)
|
| 341 |
-
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
| 342 |
-
|
| 343 |
-
if frame is None: continue
|
| 344 |
-
frame = cv2.resize(frame, (640, 480))
|
| 345 |
-
|
| 346 |
-
if mlp_pipeline is not None:
|
| 347 |
-
out = mlp_pipeline.process_frame(frame)
|
| 348 |
-
|
| 349 |
-
is_focused = out["is_focused"]
|
| 350 |
-
confidence = out["mlp_prob"]
|
| 351 |
-
metadata = {
|
| 352 |
-
"s_face": out["s_face"],
|
| 353 |
-
"s_eye": out["s_eye"],
|
| 354 |
-
"mar": out["mar"]
|
| 355 |
-
}
|
| 356 |
-
else:
|
| 357 |
-
is_focused = False
|
| 358 |
-
confidence = 0.0
|
| 359 |
-
metadata = {}
|
| 360 |
-
|
| 361 |
-
detections = []
|
| 362 |
-
|
| 363 |
-
if session_id:
|
| 364 |
-
await store_focus_event(session_id, is_focused, confidence, metadata)
|
| 365 |
-
|
| 366 |
-
await websocket.send_json({
|
| 367 |
-
'type': 'detection',
|
| 368 |
-
'focused': is_focused,
|
| 369 |
-
'confidence': round(confidence, 3),
|
| 370 |
-
'detections': detections,
|
| 371 |
-
'frame_count': frame_count
|
| 372 |
-
})
|
| 373 |
-
frame_count += 1
|
| 374 |
-
except Exception as e:
|
| 375 |
-
print(f"Error processing frame: {e}")
|
| 376 |
-
await websocket.send_json({'type': 'error', 'message': str(e)})
|
| 377 |
-
|
| 378 |
-
elif data['type'] == 'start_session':
|
| 379 |
-
session_id = await create_session()
|
| 380 |
-
await websocket.send_json({'type': 'session_started', 'session_id': session_id})
|
| 381 |
-
|
| 382 |
-
elif data['type'] == 'end_session':
|
| 383 |
-
if session_id:
|
| 384 |
-
print(f"Ending session {session_id}...")
|
| 385 |
-
summary = await end_session(session_id)
|
| 386 |
-
print(f"Session summary: {summary}")
|
| 387 |
-
if summary:
|
| 388 |
-
await websocket.send_json({'type': 'session_ended', 'summary': summary})
|
| 389 |
-
print("Session ended message sent")
|
| 390 |
-
else:
|
| 391 |
-
print("Warning: No summary returned")
|
| 392 |
-
session_id = None
|
| 393 |
-
else:
|
| 394 |
-
print("Warning: end_session called but no active session_id")
|
| 395 |
-
|
| 396 |
-
except WebSocketDisconnect:
|
| 397 |
-
if session_id: await end_session(session_id)
|
| 398 |
-
except Exception as e:
|
| 399 |
-
if websocket.client_state.value == 1: await websocket.close()
|
| 400 |
-
|
| 401 |
-
# ================ API ENDPOINTS ================
|
| 402 |
-
|
| 403 |
-
@app.post("/api/sessions/start")
|
| 404 |
-
async def api_start_session():
|
| 405 |
-
session_id = await create_session()
|
| 406 |
-
return {"session_id": session_id}
|
| 407 |
-
|
| 408 |
-
@app.post("/api/sessions/end")
|
| 409 |
-
async def api_end_session(data: SessionEnd):
|
| 410 |
-
summary = await end_session(data.session_id)
|
| 411 |
-
if not summary: raise HTTPException(status_code=404, detail="Session not found")
|
| 412 |
-
return summary
|
| 413 |
-
|
| 414 |
-
@app.get("/api/sessions")
|
| 415 |
-
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
|
| 416 |
-
async with aiosqlite.connect(db_path) as db:
|
| 417 |
-
db.row_factory = aiosqlite.Row
|
| 418 |
-
|
| 419 |
-
# NEW: If importing/exporting all, remove limit if special flag or high limit
|
| 420 |
-
# For simplicity: if limit is -1, return all
|
| 421 |
-
limit_clause = "LIMIT ? OFFSET ?"
|
| 422 |
-
params = []
|
| 423 |
-
|
| 424 |
-
base_query = "SELECT * FROM focus_sessions"
|
| 425 |
-
where_clause = ""
|
| 426 |
-
|
| 427 |
-
if filter == "today":
|
| 428 |
-
date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
| 429 |
-
where_clause = " WHERE start_time >= ?"
|
| 430 |
-
params.append(date_filter.isoformat())
|
| 431 |
-
elif filter == "week":
|
| 432 |
-
date_filter = datetime.now() - timedelta(days=7)
|
| 433 |
-
where_clause = " WHERE start_time >= ?"
|
| 434 |
-
params.append(date_filter.isoformat())
|
| 435 |
-
elif filter == "month":
|
| 436 |
-
date_filter = datetime.now() - timedelta(days=30)
|
| 437 |
-
where_clause = " WHERE start_time >= ?"
|
| 438 |
-
params.append(date_filter.isoformat())
|
| 439 |
-
elif filter == "all":
|
| 440 |
-
# Just ensure we only get completed sessions or all sessions
|
| 441 |
-
where_clause = " WHERE end_time IS NOT NULL"
|
| 442 |
-
|
| 443 |
-
query = f"{base_query}{where_clause} ORDER BY start_time DESC"
|
| 444 |
-
|
| 445 |
-
# Handle Limit for Exports
|
| 446 |
-
if limit == -1:
|
| 447 |
-
# No limit clause for export
|
| 448 |
-
pass
|
| 449 |
-
else:
|
| 450 |
-
query += f" {limit_clause}"
|
| 451 |
-
params.extend([limit, offset])
|
| 452 |
-
|
| 453 |
-
cursor = await db.execute(query, tuple(params))
|
| 454 |
-
rows = await cursor.fetchall()
|
| 455 |
-
return [dict(row) for row in rows]
|
| 456 |
-
|
| 457 |
-
# --- NEW: Import Endpoint ---
|
| 458 |
-
@app.post("/api/import")
|
| 459 |
-
async def import_sessions(sessions: List[dict]):
|
| 460 |
-
count = 0
|
| 461 |
-
try:
|
| 462 |
-
async with aiosqlite.connect(db_path) as db:
|
| 463 |
-
for session in sessions:
|
| 464 |
-
# Use .get() to handle potential missing fields from older versions or edits
|
| 465 |
-
await db.execute("""
|
| 466 |
-
INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
|
| 467 |
-
VALUES (?, ?, ?, ?, ?, ?, ?)
|
| 468 |
-
""", (
|
| 469 |
-
session.get('start_time'),
|
| 470 |
-
session.get('end_time'),
|
| 471 |
-
session.get('duration_seconds', 0),
|
| 472 |
-
session.get('focus_score', 0.0),
|
| 473 |
-
session.get('total_frames', 0),
|
| 474 |
-
session.get('focused_frames', 0),
|
| 475 |
-
session.get('created_at', session.get('start_time'))
|
| 476 |
-
))
|
| 477 |
-
count += 1
|
| 478 |
-
await db.commit()
|
| 479 |
-
return {"status": "success", "count": count}
|
| 480 |
-
except Exception as e:
|
| 481 |
-
print(f"Import Error: {e}")
|
| 482 |
-
return {"status": "error", "message": str(e)}
|
| 483 |
-
|
| 484 |
-
# --- NEW: Clear History Endpoint ---
|
| 485 |
-
@app.delete("/api/history")
|
| 486 |
-
async def clear_history():
|
| 487 |
-
try:
|
| 488 |
-
async with aiosqlite.connect(db_path) as db:
|
| 489 |
-
# Delete events first (foreign key good practice)
|
| 490 |
-
await db.execute("DELETE FROM focus_events")
|
| 491 |
-
await db.execute("DELETE FROM focus_sessions")
|
| 492 |
-
await db.commit()
|
| 493 |
-
return {"status": "success", "message": "History cleared"}
|
| 494 |
-
except Exception as e:
|
| 495 |
-
return {"status": "error", "message": str(e)}
|
| 496 |
-
|
| 497 |
-
@app.get("/api/sessions/{session_id}")
|
| 498 |
-
async def get_session(session_id: int):
|
| 499 |
-
async with aiosqlite.connect(db_path) as db:
|
| 500 |
-
db.row_factory = aiosqlite.Row
|
| 501 |
-
cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
|
| 502 |
-
row = await cursor.fetchone()
|
| 503 |
-
if not row: raise HTTPException(status_code=404, detail="Session not found")
|
| 504 |
-
session = dict(row)
|
| 505 |
-
cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
|
| 506 |
-
events = [dict(r) for r in await cursor.fetchall()]
|
| 507 |
-
session['events'] = events
|
| 508 |
-
return session
|
| 509 |
-
|
| 510 |
-
@app.get("/api/settings")
|
| 511 |
-
async def get_settings():
|
| 512 |
-
async with aiosqlite.connect(db_path) as db:
|
| 513 |
-
db.row_factory = aiosqlite.Row
|
| 514 |
-
cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
|
| 515 |
-
row = await cursor.fetchone()
|
| 516 |
-
if row: return dict(row)
|
| 517 |
-
else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'yolov8n.pt'}
|
| 518 |
-
|
| 519 |
-
@app.put("/api/settings")
|
| 520 |
-
async def update_settings(settings: SettingsUpdate):
|
| 521 |
-
async with aiosqlite.connect(db_path) as db:
|
| 522 |
-
cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
|
| 523 |
-
exists = await cursor.fetchone()
|
| 524 |
-
if not exists:
|
| 525 |
-
await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
|
| 526 |
-
await db.commit()
|
| 527 |
-
|
| 528 |
-
updates = []
|
| 529 |
-
params = []
|
| 530 |
-
if settings.sensitivity is not None:
|
| 531 |
-
updates.append("sensitivity = ?")
|
| 532 |
-
params.append(max(1, min(10, settings.sensitivity)))
|
| 533 |
-
if settings.notification_enabled is not None:
|
| 534 |
-
updates.append("notification_enabled = ?")
|
| 535 |
-
params.append(settings.notification_enabled)
|
| 536 |
-
if settings.notification_threshold is not None:
|
| 537 |
-
updates.append("notification_threshold = ?")
|
| 538 |
-
params.append(max(5, min(300, settings.notification_threshold)))
|
| 539 |
-
if settings.frame_rate is not None:
|
| 540 |
-
updates.append("frame_rate = ?")
|
| 541 |
-
params.append(max(5, min(60, settings.frame_rate)))
|
| 542 |
-
|
| 543 |
-
if updates:
|
| 544 |
-
query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
|
| 545 |
-
await db.execute(query, params)
|
| 546 |
-
await db.commit()
|
| 547 |
-
return {"status": "success", "updated": len(updates) > 0}
|
| 548 |
-
|
| 549 |
-
@app.get("/api/stats/summary")
|
| 550 |
-
async def get_stats_summary():
|
| 551 |
-
async with aiosqlite.connect(db_path) as db:
|
| 552 |
-
cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 553 |
-
total_sessions = (await cursor.fetchone())[0]
|
| 554 |
-
cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 555 |
-
total_focus_time = (await cursor.fetchone())[0] or 0
|
| 556 |
-
cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 557 |
-
avg_focus_score = (await cursor.fetchone())[0] or 0.0
|
| 558 |
-
cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
|
| 559 |
-
dates = [row[0] for row in await cursor.fetchall()]
|
| 560 |
-
|
| 561 |
-
streak_days = 0
|
| 562 |
-
if dates:
|
| 563 |
-
current_date = datetime.now().date()
|
| 564 |
-
for i, date_str in enumerate(dates):
|
| 565 |
-
session_date = datetime.fromisoformat(date_str).date()
|
| 566 |
-
expected_date = current_date - timedelta(days=i)
|
| 567 |
-
if session_date == expected_date: streak_days += 1
|
| 568 |
-
else: break
|
| 569 |
-
return {
|
| 570 |
-
'total_sessions': total_sessions,
|
| 571 |
-
'total_focus_time': int(total_focus_time),
|
| 572 |
-
'avg_focus_score': round(avg_focus_score, 3),
|
| 573 |
-
'streak_days': streak_days
|
| 574 |
-
}
|
| 575 |
-
|
| 576 |
-
@app.get("/health")
|
| 577 |
-
async def health_check():
|
| 578 |
-
return {"status": "healthy", "model_loaded": mlp_pipeline is not None, "database": os.path.exists(db_path)}
|
| 579 |
-
|
| 580 |
-
# ================ STATIC FILES (SPA SUPPORT) ================
|
| 581 |
-
|
| 582 |
-
FRONTEND_DIR = "dist" if os.path.exists("dist/index.html") else "static"
|
| 583 |
-
|
| 584 |
-
assets_path = os.path.join(FRONTEND_DIR, "assets")
|
| 585 |
-
if os.path.exists(assets_path):
|
| 586 |
-
app.mount("/assets", StaticFiles(directory=assets_path), name="assets")
|
| 587 |
-
|
| 588 |
-
@app.get("/{full_path:path}")
|
| 589 |
-
async def serve_react_app(full_path: str, request: Request):
|
| 590 |
-
if full_path.startswith("api") or full_path.startswith("ws"):
|
| 591 |
-
raise HTTPException(status_code=404, detail="Not Found")
|
| 592 |
-
|
| 593 |
-
file_path = os.path.join(FRONTEND_DIR, full_path)
|
| 594 |
-
if os.path.isfile(file_path):
|
| 595 |
-
return FileResponse(file_path)
|
| 596 |
-
|
| 597 |
-
index_path = os.path.join(FRONTEND_DIR, "index.html")
|
| 598 |
-
if os.path.exists(index_path):
|
| 599 |
-
return FileResponse(index_path)
|
| 600 |
-
else:
|
| 601 |
-
return {"message": "React app not found. Please run npm run build."}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|