Kexin-251202 commited on
Commit
c86c45b
·
verified ·
1 Parent(s): 49eb313

Deploy base model

Browse files

changed main.py(Fast API, delete yolo), Dockerfile, requirements

.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ dist
12
+ dist-ssr
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/*
17
+ !.vscode/extensions.json
18
+ .idea
19
+ .DS_Store
20
+ *.suo
21
+ *.ntvs*
22
+ *.njsproj
23
+ *.sln
24
+ *.sw?
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ RUN useradd -m -u 1000 user
4
+ ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
5
+
6
+ WORKDIR /app
7
+
8
+ RUN apt-get update && apt-get install -y --no-install-recommends libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev curl build-essential && rm -rf /var/lib/apt/lists/*
9
+
10
+ RUN curl -fsSL | bash - && apt-get install -y --no-install-recommends nodejs && rm -rf /var/lib/apt/lists/*
11
+
12
+ COPY requirements.txt ./
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ ENV FOCUSGUARD_CACHE_DIR=/app/.cache/focusguard
16
+ RUN python -c "from models.face_mesh import _ensure_model; _ensure_model()"
17
+
18
+ RUN mkdir -p /app/data && chown -R user:user /app
19
+
20
+ USER user
21
+ EXPOSE 7860
22
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
api/history ADDED
File without changes
api/import ADDED
File without changes
api/sessions ADDED
File without changes
app.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from main import app
checkpoints/gru_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb913f877e7976dea927d1f05503a4ca8ed423720688fc1f911595dcd6260746
3
+ size 170769
checkpoints/gru_meta_best.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa3424cb87c51e9ad30e6fc0139abd095b201c91bcde47682749886739b360d9
3
+ size 3282
checkpoints/gru_scaler_best.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e308ab172128506873b811b0cdea5a90fcde2ecff170d792d101c77b80d63e5
3
+ size 648
checkpoints/hybrid_focus_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "w_mlp": 0.6000000000000001,
3
+ "w_geo": 0.3999999999999999,
4
+ "threshold": 0.35,
5
+ "use_yawn_veto": true,
6
+ "geo_face_weight": 0.4,
7
+ "geo_eye_weight": 0.6,
8
+ "mar_yawn_threshold": 0.55,
9
+ "metric": "f1"
10
+ }
checkpoints/meta_best.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d78d1df5e25536a2c82c4b8f5fd0c26dd35f44b28fd59761634cbf78c7546f8
3
+ size 4196
checkpoints/model_best.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:183f2d4419e0eb1e58704e5a7312eb61e331523566d4dc551054a07b3aac7557
3
+ size 5775881
checkpoints/scaler_best.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02ed6b4c0d99e0254c6a740a949da2384db58ec7d3e6df6432b9bfcd3a296c71
3
+ size 783
docker-compose.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ services:
2
+ focus-guard:
3
+ build: .
4
+ ports:
5
+ - "7860:7860"
eslint.config.js ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import { defineConfig, globalIgnores } from 'eslint/config'
6
+
7
+ export default defineConfig([
8
+ globalIgnores(['dist']),
9
+ {
10
+ files: ['**/*.{js,jsx}'],
11
+ extends: [
12
+ js.configs.recommended,
13
+ reactHooks.configs.flat.recommended,
14
+ reactRefresh.configs.vite,
15
+ ],
16
+ languageOptions: {
17
+ ecmaVersion: 2020,
18
+ globals: globals.browser,
19
+ parserOptions: {
20
+ ecmaVersion: 'latest',
21
+ ecmaFeatures: { jsx: true },
22
+ sourceType: 'module',
23
+ },
24
+ },
25
+ rules: {
26
+ 'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
27
+ },
28
+ },
29
+ ])
index.html ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8" />
6
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
7
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
8
+ <title>哈哈卧槽几点了</title>
9
+ <link href="https://fonts.googleapis.com/css2?family=Nunito:wght@400;700&display=swap" rel="stylesheet">
10
+ </head>
11
+
12
+ <body>
13
+ <div id="root"></div>
14
+ <script type="module" src="/src/main.jsx"></script>
15
+ </body>
16
+
17
+ </html>
main.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
2
+ from fastapi.staticfiles import StaticFiles
3
+ from fastapi.responses import FileResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ from typing import Optional, List, Any
7
+ import base64
8
+ import cv2
9
+ import numpy as np
10
+ import aiosqlite
11
+ import json
12
+ from datetime import datetime, timedelta
13
+ import math
14
+ import os
15
+ from pathlib import Path
16
+ from typing import Callable
17
+ import asyncio
18
+
19
+ from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
20
+ from av import VideoFrame
21
+
22
+ from ui.pipeline import MLPPipeline
23
+
24
+ # Initialize FastAPI app
25
+ app = FastAPI(title="Focus Guard API")
26
+
27
+ # Add CORS middleware
28
+ app.add_middleware(
29
+ CORSMiddleware,
30
+ allow_origins=["*"],
31
+ allow_credentials=True,
32
+ allow_methods=["*"],
33
+ allow_headers=["*"],
34
+ )
35
+
36
+ # Global variables
37
+ db_path = "focus_guard.db"
38
+ pcs = set()
39
+
40
+ async def _wait_for_ice_gathering(pc: RTCPeerConnection):
41
+ if pc.iceGatheringState == "complete":
42
+ return
43
+ done = asyncio.Event()
44
+
45
+ @pc.on("icegatheringstatechange")
46
+ def _on_state_change():
47
+ if pc.iceGatheringState == "complete":
48
+ done.set()
49
+
50
+ await done.wait()
51
+
52
+ # ================ DATABASE MODELS ================
53
+
54
+ async def init_database():
55
+ """Initialize SQLite database with required tables"""
56
+ async with aiosqlite.connect(db_path) as db:
57
+ # FocusSessions table
58
+ await db.execute("""
59
+ CREATE TABLE IF NOT EXISTS focus_sessions (
60
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
61
+ start_time TIMESTAMP NOT NULL,
62
+ end_time TIMESTAMP,
63
+ duration_seconds INTEGER DEFAULT 0,
64
+ focus_score REAL DEFAULT 0.0,
65
+ total_frames INTEGER DEFAULT 0,
66
+ focused_frames INTEGER DEFAULT 0,
67
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
68
+ )
69
+ """)
70
+
71
+ # FocusEvents table
72
+ await db.execute("""
73
+ CREATE TABLE IF NOT EXISTS focus_events (
74
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
75
+ session_id INTEGER NOT NULL,
76
+ timestamp TIMESTAMP NOT NULL,
77
+ is_focused BOOLEAN NOT NULL,
78
+ confidence REAL NOT NULL,
79
+ detection_data TEXT,
80
+ FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
81
+ )
82
+ """)
83
+
84
+ # UserSettings table
85
+ await db.execute("""
86
+ CREATE TABLE IF NOT EXISTS user_settings (
87
+ id INTEGER PRIMARY KEY CHECK (id = 1),
88
+ sensitivity INTEGER DEFAULT 6,
89
+ notification_enabled BOOLEAN DEFAULT 1,
90
+ notification_threshold INTEGER DEFAULT 30,
91
+ frame_rate INTEGER DEFAULT 30,
92
+ model_name TEXT DEFAULT 'yolov8n.pt'
93
+ )
94
+ """)
95
+
96
+ # Insert default settings if not exists
97
+ await db.execute("""
98
+ INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
99
+ VALUES (1, 6, 1, 30, 30, 'yolov8n.pt')
100
+ """)
101
+
102
+ await db.commit()
103
+
104
+ # ================ PYDANTIC MODELS ================
105
+
106
+ class SessionCreate(BaseModel):
107
+ pass
108
+
109
+ class SessionEnd(BaseModel):
110
+ session_id: int
111
+
112
+ class SettingsUpdate(BaseModel):
113
+ sensitivity: Optional[int] = None
114
+ notification_enabled: Optional[bool] = None
115
+ notification_threshold: Optional[int] = None
116
+ frame_rate: Optional[int] = None
117
+
118
+ class VideoTransformTrack(VideoStreamTrack):
119
+ def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
120
+ super().__init__()
121
+ self.track = track
122
+ self.session_id = session_id
123
+ self.get_channel = get_channel
124
+ self.last_inference_time = 0
125
+ self.min_inference_interval = 1 / 60
126
+ self.last_frame = None
127
+
128
+ async def recv(self):
129
+ frame = await self.track.recv()
130
+ img = frame.to_ndarray(format="bgr24")
131
+ if img is None:
132
+ return frame
133
+
134
+ # Normalize size for inference/drawing
135
+ img = cv2.resize(img, (640, 480))
136
+
137
+ now = datetime.now().timestamp()
138
+ do_infer = (now - self.last_inference_time) >= self.min_inference_interval
139
+
140
+ if do_infer and mlp_pipeline is not None:
141
+ self.last_inference_time = now
142
+ out = mlp_pipeline.process_frame(img)detections = parse_yolo_results(results)
143
+ is_focused = out["is_focused"]
144
+ confidence = out["mlp_prob"]
145
+ metadata = {"s_face": out["s_face"], "s_eye": out["s_eye"], "mar": out["mar"]}
146
+ detections = []
147
+ status_text = "FOCUSED" if is_focused else "NOT FOCUSED"
148
+ color = (0, 255, 0) if is_focused else (0, 0, 255)
149
+ cv2.putText(img, status_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
150
+ cv2.putText(img, f"Confidence: {confidence * 100:.1f}%", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
151
+
152
+ if self.session_id:
153
+ await store_focus_event(self.session_id, is_focused, confidence, metadata)
154
+
155
+ channel = self.get_channel()
156
+ if channel and channel.readyState == "open":
157
+ try:
158
+ channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": detections}))
159
+ except Exception:
160
+ pass
161
+
162
+ self.last_frame = img
163
+ elif self.last_frame is not None:
164
+ img = self.last_frame
165
+
166
+ new_frame = VideoFrame.from_ndarray(img, format="bgr24")
167
+ new_frame.pts = frame.pts
168
+ new_frame.time_base = frame.time_base
169
+ return new_frame
170
+
171
+ # ================ DATABASE OPERATIONS ================
172
+
173
+ async def create_session():
174
+ async with aiosqlite.connect(db_path) as db:
175
+ cursor = await db.execute(
176
+ "INSERT INTO focus_sessions (start_time) VALUES (?)",
177
+ (datetime.now().isoformat(),)
178
+ )
179
+ await db.commit()
180
+ return cursor.lastrowid
181
+
182
+ async def end_session(session_id: int):
183
+ async with aiosqlite.connect(db_path) as db:
184
+ cursor = await db.execute(
185
+ "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
186
+ (session_id,)
187
+ )
188
+ row = await cursor.fetchone()
189
+
190
+ if not row:
191
+ return None
192
+
193
+ start_time_str, total_frames, focused_frames = row
194
+ start_time = datetime.fromisoformat(start_time_str)
195
+ end_time = datetime.now()
196
+ duration = (end_time - start_time).total_seconds()
197
+ focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
198
+
199
+ await db.execute("""
200
+ UPDATE focus_sessions
201
+ SET end_time = ?, duration_seconds = ?, focus_score = ?
202
+ WHERE id = ?
203
+ """, (end_time.isoformat(), int(duration), focus_score, session_id))
204
+
205
+ await db.commit()
206
+
207
+ return {
208
+ 'session_id': session_id,
209
+ 'start_time': start_time_str,
210
+ 'end_time': end_time.isoformat(),
211
+ 'duration_seconds': int(duration),
212
+ 'focus_score': round(focus_score, 3),
213
+ 'total_frames': total_frames,
214
+ 'focused_frames': focused_frames
215
+ }
216
+
217
+ async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
218
+ async with aiosqlite.connect(db_path) as db:
219
+ await db.execute("""
220
+ INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
221
+ VALUES (?, ?, ?, ?, ?)
222
+ """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
223
+
224
+ await db.execute(f"""
225
+ UPDATE focus_sessions
226
+ SET total_frames = total_frames + 1,
227
+ focused_frames = focused_frames + {1 if is_focused else 0}
228
+ WHERE id = ?
229
+ """, (session_id,))
230
+ await db.commit()
231
+
232
+ # ================ STARTUP/SHUTDOWN ================
233
+
234
+ mlp_pipeline = None
235
+
236
+ @app.on_event("startup")
237
+ async def startup_event():
238
+ global mlp_pipeline
239
+ print(" Starting Focus Guard API...")
240
+ await init_database()
241
+ print("[OK] Database initialized")
242
+
243
+ mlp_pipeline = MLPPipeline()
244
+ print("[OK] MLPPipeline loaded")
245
+
246
+ @app.on_event("shutdown")
247
+ async def shutdown_event():
248
+ print(" Shutting down Focus Guard API...")
249
+
250
+ # ================ WEBRTC SIGNALING ================
251
+
252
+ @app.post("/api/webrtc/offer")
253
+ async def webrtc_offer(offer: dict):
254
+ try:
255
+ print(f"Received WebRTC offer")
256
+
257
+ pc = RTCPeerConnection()
258
+ pcs.add(pc)
259
+
260
+ session_id = await create_session()
261
+ print(f"Created session: {session_id}")
262
+
263
+ channel_ref = {"channel": None}
264
+
265
+ @pc.on("datachannel")
266
+ def on_datachannel(channel):
267
+ print(f"Data channel opened")
268
+ channel_ref["channel"] = channel
269
+
270
+ @pc.on("track")
271
+ def on_track(track):
272
+ print(f"Received track: {track.kind}")
273
+ if track.kind == "video":
274
+ local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
275
+ pc.addTrack(local_track)
276
+ print(f"Video track added")
277
+
278
+ @track.on("ended")
279
+ async def on_ended():
280
+ print(f"Track ended")
281
+
282
+ @pc.on("connectionstatechange")
283
+ async def on_connectionstatechange():
284
+ print(f"Connection state changed: {pc.connectionState}")
285
+ if pc.connectionState in ("failed", "closed", "disconnected"):
286
+ try:
287
+ await end_session(session_id)
288
+ except Exception as e:
289
+ print(f"⚠Error ending session: {e}")
290
+ pcs.discard(pc)
291
+ await pc.close()
292
+
293
+ await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
294
+ print(f"Remote description set")
295
+
296
+ answer = await pc.createAnswer()
297
+ await pc.setLocalDescription(answer)
298
+ print(f"Answer created")
299
+
300
+ await _wait_for_ice_gathering(pc)
301
+ print(f"ICE gathering complete")
302
+
303
+ return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
304
+
305
+ except Exception as e:
306
+ print(f"WebRTC offer error: {e}")
307
+ import traceback
308
+ traceback.print_exc()
309
+ raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
310
+
311
+ # ================ WEBSOCKET ================
312
+
313
+ @app.websocket("/ws/video")
314
+ async def websocket_endpoint(websocket: WebSocket):
315
+ await websocket.accept()
316
+ session_id = None
317
+ frame_count = 0
318
+ last_inference_time = 0
319
+ min_inference_interval = 1 / 60
320
+
321
+ try:
322
+ async with aiosqlite.connect(db_path) as db:
323
+ cursor = await db.execute("SELECT sensitivity FROM user_settings WHERE id = 1")
324
+ row = await cursor.fetchone()
325
+ sensitivity = row[0] if row else 6
326
+
327
+ while True:
328
+ data = await websocket.receive_json()
329
+
330
+ if data['type'] == 'frame':
331
+ from time import time
332
+ current_time = time()
333
+ if current_time - last_inference_time < min_inference_interval:
334
+ await websocket.send_json({'type': 'ack', 'frame_count': frame_count})
335
+ continue
336
+ last_inference_time = current_time
337
+
338
+ try:
339
+ img_data = base64.b64decode(data['image'])
340
+ nparr = np.frombuffer(img_data, np.uint8)
341
+ frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
342
+
343
+ if frame is None: continue
344
+ frame = cv2.resize(frame, (640, 480))
345
+
346
+ if mlp_pipeline is not None:
347
+ out = mlp_pipeline.process_frame(frame)
348
+
349
+ is_focused = out["is_focused"]
350
+ confidence = out["mlp_prob"]
351
+ metadata = {
352
+ "s_face": out["s_face"],
353
+ "s_eye": out["s_eye"],
354
+ "mar": out["mar"]
355
+ }
356
+ else:
357
+ is_focused = False
358
+ confidence = 0.0
359
+ metadata = {}
360
+
361
+ detections = []
362
+
363
+ if session_id:
364
+ await store_focus_event(session_id, is_focused, confidence, metadata)
365
+
366
+ await websocket.send_json({
367
+ 'type': 'detection',
368
+ 'focused': is_focused,
369
+ 'confidence': round(confidence, 3),
370
+ 'detections': detections,
371
+ 'frame_count': frame_count
372
+ })
373
+ frame_count += 1
374
+ except Exception as e:
375
+ print(f"Error processing frame: {e}")
376
+ await websocket.send_json({'type': 'error', 'message': str(e)})
377
+
378
+ elif data['type'] == 'start_session':
379
+ session_id = await create_session()
380
+ await websocket.send_json({'type': 'session_started', 'session_id': session_id})
381
+
382
+ elif data['type'] == 'end_session':
383
+ if session_id:
384
+ print(f"Ending session {session_id}...")
385
+ summary = await end_session(session_id)
386
+ print(f"Session summary: {summary}")
387
+ if summary:
388
+ await websocket.send_json({'type': 'session_ended', 'summary': summary})
389
+ print("Session ended message sent")
390
+ else:
391
+ print("Warning: No summary returned")
392
+ session_id = None
393
+ else:
394
+ print("Warning: end_session called but no active session_id")
395
+
396
+ except WebSocketDisconnect:
397
+ if session_id: await end_session(session_id)
398
+ except Exception as e:
399
+ if websocket.client_state.value == 1: await websocket.close()
400
+
401
+ # ================ API ENDPOINTS ================
402
+
403
+ @app.post("/api/sessions/start")
404
+ async def api_start_session():
405
+ session_id = await create_session()
406
+ return {"session_id": session_id}
407
+
408
+ @app.post("/api/sessions/end")
409
+ async def api_end_session(data: SessionEnd):
410
+ summary = await end_session(data.session_id)
411
+ if not summary: raise HTTPException(status_code=404, detail="Session not found")
412
+ return summary
413
+
414
+ @app.get("/api/sessions")
415
+ async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
416
+ async with aiosqlite.connect(db_path) as db:
417
+ db.row_factory = aiosqlite.Row
418
+
419
+ # NEW: If importing/exporting all, remove limit if special flag or high limit
420
+ # For simplicity: if limit is -1, return all
421
+ limit_clause = "LIMIT ? OFFSET ?"
422
+ params = []
423
+
424
+ base_query = "SELECT * FROM focus_sessions"
425
+ where_clause = ""
426
+
427
+ if filter == "today":
428
+ date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
429
+ where_clause = " WHERE start_time >= ?"
430
+ params.append(date_filter.isoformat())
431
+ elif filter == "week":
432
+ date_filter = datetime.now() - timedelta(days=7)
433
+ where_clause = " WHERE start_time >= ?"
434
+ params.append(date_filter.isoformat())
435
+ elif filter == "month":
436
+ date_filter = datetime.now() - timedelta(days=30)
437
+ where_clause = " WHERE start_time >= ?"
438
+ params.append(date_filter.isoformat())
439
+ elif filter == "all":
440
+ # Just ensure we only get completed sessions or all sessions
441
+ where_clause = " WHERE end_time IS NOT NULL"
442
+
443
+ query = f"{base_query}{where_clause} ORDER BY start_time DESC"
444
+
445
+ # Handle Limit for Exports
446
+ if limit == -1:
447
+ # No limit clause for export
448
+ pass
449
+ else:
450
+ query += f" {limit_clause}"
451
+ params.extend([limit, offset])
452
+
453
+ cursor = await db.execute(query, tuple(params))
454
+ rows = await cursor.fetchall()
455
+ return [dict(row) for row in rows]
456
+
457
+ # --- NEW: Import Endpoint ---
458
+ @app.post("/api/import")
459
+ async def import_sessions(sessions: List[dict]):
460
+ count = 0
461
+ try:
462
+ async with aiosqlite.connect(db_path) as db:
463
+ for session in sessions:
464
+ # Use .get() to handle potential missing fields from older versions or edits
465
+ await db.execute("""
466
+ INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
467
+ VALUES (?, ?, ?, ?, ?, ?, ?)
468
+ """, (
469
+ session.get('start_time'),
470
+ session.get('end_time'),
471
+ session.get('duration_seconds', 0),
472
+ session.get('focus_score', 0.0),
473
+ session.get('total_frames', 0),
474
+ session.get('focused_frames', 0),
475
+ session.get('created_at', session.get('start_time'))
476
+ ))
477
+ count += 1
478
+ await db.commit()
479
+ return {"status": "success", "count": count}
480
+ except Exception as e:
481
+ print(f"Import Error: {e}")
482
+ return {"status": "error", "message": str(e)}
483
+
484
+ # --- NEW: Clear History Endpoint ---
485
+ @app.delete("/api/history")
486
+ async def clear_history():
487
+ try:
488
+ async with aiosqlite.connect(db_path) as db:
489
+ # Delete events first (foreign key good practice)
490
+ await db.execute("DELETE FROM focus_events")
491
+ await db.execute("DELETE FROM focus_sessions")
492
+ await db.commit()
493
+ return {"status": "success", "message": "History cleared"}
494
+ except Exception as e:
495
+ return {"status": "error", "message": str(e)}
496
+
497
+ @app.get("/api/sessions/{session_id}")
498
+ async def get_session(session_id: int):
499
+ async with aiosqlite.connect(db_path) as db:
500
+ db.row_factory = aiosqlite.Row
501
+ cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
502
+ row = await cursor.fetchone()
503
+ if not row: raise HTTPException(status_code=404, detail="Session not found")
504
+ session = dict(row)
505
+ cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
506
+ events = [dict(r) for r in await cursor.fetchall()]
507
+ session['events'] = events
508
+ return session
509
+
510
+ @app.get("/api/settings")
511
+ async def get_settings():
512
+ async with aiosqlite.connect(db_path) as db:
513
+ db.row_factory = aiosqlite.Row
514
+ cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
515
+ row = await cursor.fetchone()
516
+ if row: return dict(row)
517
+ else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'yolov8n.pt'}
518
+
519
+ @app.put("/api/settings")
520
+ async def update_settings(settings: SettingsUpdate):
521
+ async with aiosqlite.connect(db_path) as db:
522
+ cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
523
+ exists = await cursor.fetchone()
524
+ if not exists:
525
+ await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
526
+ await db.commit()
527
+
528
+ updates = []
529
+ params = []
530
+ if settings.sensitivity is not None:
531
+ updates.append("sensitivity = ?")
532
+ params.append(max(1, min(10, settings.sensitivity)))
533
+ if settings.notification_enabled is not None:
534
+ updates.append("notification_enabled = ?")
535
+ params.append(settings.notification_enabled)
536
+ if settings.notification_threshold is not None:
537
+ updates.append("notification_threshold = ?")
538
+ params.append(max(5, min(300, settings.notification_threshold)))
539
+ if settings.frame_rate is not None:
540
+ updates.append("frame_rate = ?")
541
+ params.append(max(5, min(60, settings.frame_rate)))
542
+
543
+ if updates:
544
+ query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
545
+ await db.execute(query, params)
546
+ await db.commit()
547
+ return {"status": "success", "updated": len(updates) > 0}
548
+
549
+ @app.get("/api/stats/summary")
550
+ async def get_stats_summary():
551
+ async with aiosqlite.connect(db_path) as db:
552
+ cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
553
+ total_sessions = (await cursor.fetchone())[0]
554
+ cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
555
+ total_focus_time = (await cursor.fetchone())[0] or 0
556
+ cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
557
+ avg_focus_score = (await cursor.fetchone())[0] or 0.0
558
+ cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
559
+ dates = [row[0] for row in await cursor.fetchall()]
560
+
561
+ streak_days = 0
562
+ if dates:
563
+ current_date = datetime.now().date()
564
+ for i, date_str in enumerate(dates):
565
+ session_date = datetime.fromisoformat(date_str).date()
566
+ expected_date = current_date - timedelta(days=i)
567
+ if session_date == expected_date: streak_days += 1
568
+ else: break
569
+ return {
570
+ 'total_sessions': total_sessions,
571
+ 'total_focus_time': int(total_focus_time),
572
+ 'avg_focus_score': round(avg_focus_score, 3),
573
+ 'streak_days': streak_days
574
+ }
575
+
576
+ @app.get("/health")
577
+ async def health_check():
578
+ return {"status": "healthy", "model_loaded": mlp_pipeline is not None, "database": os.path.exists(db_path)}
579
+
580
+ # ================ STATIC FILES (SPA SUPPORT) ================
581
+
582
+ # 1. Mount the assets folder (JS/CSS built by Vite/React)
583
+ if os.path.exists("static/assets"):
584
+ app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")
585
+
586
+ # 2. Catch-all route for SPA (React Router)
587
+ # This ensures that if you refresh /customise, it serves index.html instead of 404
588
+ @app.get("/{full_path:path}")
589
+ async def serve_react_app(full_path: str, request: Request):
590
+ # Skip API and WS routes
591
+ if full_path.startswith("api") or full_path.startswith("ws"):
592
+ raise HTTPException(status_code=404, detail="Not Found")
593
+
594
+ # Serve index.html for any other route
595
+ if os.path.exists("static/index.html"):
596
+ return FileResponse("static/index.html")
597
+ else:
598
+ return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
models/collect_features.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import argparse
3
+ import collections
4
+ import math
5
+ import os
6
+ import sys
7
+ import time
8
+
9
+ import cv2
10
+ import numpy as np
11
+
12
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
+ if _PROJECT_ROOT not in sys.path:
14
+ sys.path.insert(0, _PROJECT_ROOT)
15
+
16
+ from models.face_mesh import FaceMeshDetector
17
+ from models.head_pose import HeadPoseEstimator
18
+ from models.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar
19
+
20
+ FONT = cv2.FONT_HERSHEY_SIMPLEX
21
+ GREEN = (0, 255, 0)
22
+ RED = (0, 0, 255)
23
+ WHITE = (255, 255, 255)
24
+ YELLOW = (0, 255, 255)
25
+ ORANGE = (0, 165, 255)
26
+ GRAY = (120, 120, 120)
27
+
28
+ FEATURE_NAMES = [
29
+ "ear_left", "ear_right", "ear_avg", "h_gaze", "v_gaze", "mar",
30
+ "yaw", "pitch", "roll", "s_face", "s_eye", "gaze_offset", "head_deviation",
31
+ "perclos", "blink_rate", "closure_duration", "yawn_duration",
32
+ ]
33
+
34
+ NUM_FEATURES = len(FEATURE_NAMES)
35
+ assert NUM_FEATURES == 17
36
+
37
+
38
+ class TemporalTracker:
39
+ EAR_BLINK_THRESH = 0.21
40
+ MAR_YAWN_THRESH = 0.55
41
+ PERCLOS_WINDOW = 60
42
+ BLINK_WINDOW_SEC = 30.0
43
+
44
+ def __init__(self):
45
+ self.ear_history = collections.deque(maxlen=self.PERCLOS_WINDOW)
46
+ self.blink_timestamps = collections.deque()
47
+ self._eyes_closed = False
48
+ self._closure_start = None
49
+ self._yawn_start = None
50
+
51
+ def update(self, ear_avg, mar, now=None):
52
+ if now is None:
53
+ now = time.time()
54
+
55
+ closed = ear_avg < self.EAR_BLINK_THRESH
56
+ self.ear_history.append(1.0 if closed else 0.0)
57
+ perclos = sum(self.ear_history) / len(self.ear_history) if self.ear_history else 0.0
58
+
59
+ if self._eyes_closed and not closed:
60
+ self.blink_timestamps.append(now)
61
+ self._eyes_closed = closed
62
+
63
+ cutoff = now - self.BLINK_WINDOW_SEC
64
+ while self.blink_timestamps and self.blink_timestamps[0] < cutoff:
65
+ self.blink_timestamps.popleft()
66
+ blink_rate = len(self.blink_timestamps) * (60.0 / self.BLINK_WINDOW_SEC)
67
+
68
+ if closed:
69
+ if self._closure_start is None:
70
+ self._closure_start = now
71
+ closure_dur = now - self._closure_start
72
+ else:
73
+ self._closure_start = None
74
+ closure_dur = 0.0
75
+
76
+ yawning = mar > self.MAR_YAWN_THRESH
77
+ if yawning:
78
+ if self._yawn_start is None:
79
+ self._yawn_start = now
80
+ yawn_dur = now - self._yawn_start
81
+ else:
82
+ self._yawn_start = None
83
+ yawn_dur = 0.0
84
+
85
+ return perclos, blink_rate, closure_dur, yawn_dur
86
+
87
+
88
+ def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal,
89
+ *, _pre=None):
90
+ from models.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear
91
+
92
+ p = _pre or {}
93
+
94
+ ear_left = p.get("ear_left", compute_ear(landmarks, _LEFT_EYE_EAR))
95
+ ear_right = p.get("ear_right", compute_ear(landmarks, _RIGHT_EYE_EAR))
96
+ ear_avg = (ear_left + ear_right) / 2.0
97
+
98
+ if "h_gaze" in p and "v_gaze" in p:
99
+ h_gaze, v_gaze = p["h_gaze"], p["v_gaze"]
100
+ else:
101
+ h_gaze, v_gaze = compute_gaze_ratio(landmarks)
102
+
103
+ mar = p.get("mar", compute_mar(landmarks))
104
+
105
+ angles = p.get("angles")
106
+ if angles is None:
107
+ angles = head_pose.estimate(landmarks, w, h)
108
+ yaw = angles[0] if angles else 0.0
109
+ pitch = angles[1] if angles else 0.0
110
+ roll = angles[2] if angles else 0.0
111
+
112
+ s_face = p.get("s_face", head_pose.score(landmarks, w, h))
113
+ s_eye = p.get("s_eye", eye_scorer.score(landmarks))
114
+
115
+ gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
116
+ head_deviation = math.sqrt(yaw ** 2 + pitch ** 2) # cleaned downstream
117
+
118
+ perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)
119
+
120
+ return np.array([
121
+ ear_left, ear_right, ear_avg,
122
+ h_gaze, v_gaze,
123
+ mar,
124
+ yaw, pitch, roll,
125
+ s_face, s_eye,
126
+ gaze_offset,
127
+ head_deviation,
128
+ perclos, blink_rate, closure_dur, yawn_dur,
129
+ ], dtype=np.float32)
130
+
131
+
132
+ def quality_report(labels):
133
+ n = len(labels)
134
+ n1 = int((labels == 1).sum())
135
+ n0 = n - n1
136
+ transitions = int(np.sum(np.diff(labels) != 0))
137
+ duration_sec = n / 30.0 # approximate at 30fps
138
+
139
+ warnings = []
140
+
141
+ print(f"\n{'='*50}")
142
+ print(f" DATA QUALITY REPORT")
143
+ print(f"{'='*50}")
144
+ print(f" Total samples : {n}")
145
+ print(f" Focused : {n1} ({n1/max(n,1)*100:.1f}%)")
146
+ print(f" Unfocused : {n0} ({n0/max(n,1)*100:.1f}%)")
147
+ print(f" Duration : {duration_sec:.0f}s ({duration_sec/60:.1f} min)")
148
+ print(f" Transitions : {transitions}")
149
+ if transitions > 0:
150
+ print(f" Avg segment : {n/transitions:.0f} frames ({n/transitions/30:.1f}s)")
151
+
152
+ # checks
153
+ if duration_sec < 120:
154
+ warnings.append(f"TOO SHORT: {duration_sec:.0f}s — aim for 5-10 minutes (300-600s)")
155
+
156
+ if n < 3000:
157
+ warnings.append(f"LOW SAMPLE COUNT: {n} frames — aim for 9000+ (5 min at 30fps)")
158
+
159
+ balance = n1 / max(n, 1)
160
+ if balance < 0.3 or balance > 0.7:
161
+ warnings.append(f"IMBALANCED: {balance:.0%} focused — aim for 35-65% focused")
162
+
163
+ if transitions < 10:
164
+ warnings.append(f"TOO FEW TRANSITIONS: {transitions} — switch every 10-30s, aim for 20+")
165
+
166
+ if transitions == 1:
167
+ warnings.append("SINGLE BLOCK: you recorded one unfocused + one focused block — "
168
+ "model will learn temporal position, not focus patterns")
169
+
170
+ if warnings:
171
+ print(f"\n ⚠️ WARNINGS ({len(warnings)}):")
172
+ for w in warnings:
173
+ print(f" • {w}")
174
+ print(f"\n Consider re-recording this session.")
175
+ else:
176
+ print(f"\n ✅ All checks passed!")
177
+
178
+ print(f"{'='*50}\n")
179
+ return len(warnings) == 0
180
+
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # Main
184
+ def main():
185
+ parser = argparse.ArgumentParser()
186
+ parser.add_argument("--name", type=str, default="session",
187
+ help="Your name or session ID")
188
+ parser.add_argument("--camera", type=int, default=0,
189
+ help="Camera index")
190
+ parser.add_argument("--duration", type=int, default=600,
191
+ help="Max recording time (seconds, default 10 min)")
192
+ parser.add_argument("--output-dir", type=str,
193
+ default=os.path.join(_PROJECT_ROOT, "collected_data"),
194
+ help="Where to save .npz files")
195
+ args = parser.parse_args()
196
+
197
+ os.makedirs(args.output_dir, exist_ok=True)
198
+
199
+ detector = FaceMeshDetector()
200
+ head_pose = HeadPoseEstimator()
201
+ eye_scorer = EyeBehaviourScorer()
202
+ temporal = TemporalTracker()
203
+
204
+ cap = cv2.VideoCapture(args.camera)
205
+ if not cap.isOpened():
206
+ print("[COLLECT] ERROR: can't open camera")
207
+ return
208
+
209
+ print("[COLLECT] Data Collection Tool")
210
+ print(f"[COLLECT] Session: {args.name}, max {args.duration}s")
211
+ print(f"[COLLECT] Features per frame: {NUM_FEATURES}")
212
+ print("[COLLECT] Controls:")
213
+ print(" 1 = FOCUSED (looking at screen normally)")
214
+ print(" 0 = NOT FOCUSED (phone, away, eyes closed, yawning)")
215
+ print(" p = pause")
216
+ print(" q = save & quit")
217
+ print()
218
+ print("[COLLECT] TIPS for good data:")
219
+ print(" • Switch between 1 and 0 every 10-30 seconds")
220
+ print(" • Aim for 20+ transitions total")
221
+ print(" • Act out varied scenarios: reading, phone, talking, drowsy")
222
+ print(" • Record at least 5 minutes")
223
+ print()
224
+
225
+ features_list = []
226
+ labels_list = []
227
+ label = None # None = paused
228
+ transitions = 0 # count label switches
229
+ prev_label = None
230
+ status = "PAUSED -- press 1 (focused) or 0 (not focused)"
231
+ t_start = time.time()
232
+ prev_time = time.time()
233
+ fps = 0.0
234
+
235
+ try:
236
+ while True:
237
+ elapsed = time.time() - t_start
238
+ if elapsed > args.duration:
239
+ print(f"[COLLECT] Time limit ({args.duration}s)")
240
+ break
241
+
242
+ ret, frame = cap.read()
243
+ if not ret:
244
+ break
245
+
246
+ h, w = frame.shape[:2]
247
+ landmarks = detector.process(frame)
248
+ face_ok = landmarks is not None
249
+
250
+ if face_ok and label is not None:
251
+ vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
252
+ features_list.append(vec)
253
+ labels_list.append(label)
254
+
255
+ if prev_label is not None and label != prev_label:
256
+ transitions += 1
257
+ prev_label = label
258
+
259
+ now = time.time()
260
+ fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
261
+ prev_time = now
262
+
263
+ # --- draw UI ---
264
+ n = len(labels_list)
265
+ n1 = sum(1 for x in labels_list if x == 1)
266
+ n0 = n - n1
267
+ remaining = max(0, args.duration - elapsed)
268
+
269
+ bar_color = GREEN if label == 1 else (RED if label == 0 else (80, 80, 80))
270
+ cv2.rectangle(frame, (0, 0), (w, 70), (0, 0, 0), -1)
271
+ cv2.putText(frame, status, (10, 22), FONT, 0.55, bar_color, 2, cv2.LINE_AA)
272
+ cv2.putText(frame, f"Samples: {n} (F:{n1} U:{n0}) Switches: {transitions}",
273
+ (10, 48), FONT, 0.42, WHITE, 1, cv2.LINE_AA)
274
+ cv2.putText(frame, f"FPS:{fps:.0f}", (w - 80, 22), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
275
+ cv2.putText(frame, f"{int(remaining)}s left", (w - 80, 48), FONT, 0.42, YELLOW, 1, cv2.LINE_AA)
276
+
277
+ if n > 0:
278
+ bar_w = min(w - 20, 300)
279
+ bar_x = w - bar_w - 10
280
+ bar_y = 58
281
+ frac = n1 / n
282
+ cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + 8), (40, 40, 40), -1)
283
+ cv2.rectangle(frame, (bar_x, bar_y), (bar_x + int(bar_w * frac), bar_y + 8), GREEN, -1)
284
+ cv2.putText(frame, f"{frac:.0%}F", (bar_x + bar_w + 4, bar_y + 8),
285
+ FONT, 0.3, GRAY, 1, cv2.LINE_AA)
286
+
287
+ if not face_ok:
288
+ cv2.putText(frame, "NO FACE", (w // 2 - 60, h // 2), FONT, 0.7, RED, 2, cv2.LINE_AA)
289
+
290
+ # red dot = recording
291
+ if label is not None and face_ok:
292
+ cv2.circle(frame, (w - 20, 80), 8, RED, -1)
293
+
294
+ # live warnings
295
+ warn_y = h - 35
296
+ if n > 100 and transitions < 3:
297
+ cv2.putText(frame, "! Switch more often (aim for 20+ transitions)",
298
+ (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
299
+ warn_y -= 18
300
+ if elapsed > 30 and n > 0:
301
+ bal = n1 / n
302
+ if bal < 0.25 or bal > 0.75:
303
+ cv2.putText(frame, f"! Imbalanced ({bal:.0%} focused) - record more of the other",
304
+ (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
305
+ warn_y -= 18
306
+
307
+ cv2.putText(frame, "1:focused 0:unfocused p:pause q:save+quit",
308
+ (10, h - 10), FONT, 0.38, GRAY, 1, cv2.LINE_AA)
309
+
310
+ cv2.imshow("FocusGuard -- Data Collection", frame)
311
+
312
+ key = cv2.waitKey(1) & 0xFF
313
+ if key == ord("1"):
314
+ label = 1
315
+ status = "Recording: FOCUSED"
316
+ print(f"[COLLECT] -> FOCUSED (n={n}, transitions={transitions})")
317
+ elif key == ord("0"):
318
+ label = 0
319
+ status = "Recording: NOT FOCUSED"
320
+ print(f"[COLLECT] -> NOT FOCUSED (n={n}, transitions={transitions})")
321
+ elif key == ord("p"):
322
+ label = None
323
+ status = "PAUSED"
324
+ print(f"[COLLECT] paused (n={n})")
325
+ elif key == ord("q"):
326
+ break
327
+
328
+ finally:
329
+ cap.release()
330
+ cv2.destroyAllWindows()
331
+ detector.close()
332
+
333
+ if len(features_list) > 0:
334
+ feats = np.stack(features_list)
335
+ labs = np.array(labels_list, dtype=np.int64)
336
+
337
+ ts = time.strftime("%Y%m%d_%H%M%S")
338
+ fname = f"{args.name}_{ts}.npz"
339
+ fpath = os.path.join(args.output_dir, fname)
340
+ np.savez(fpath,
341
+ features=feats,
342
+ labels=labs,
343
+ feature_names=np.array(FEATURE_NAMES))
344
+
345
+ print(f"\n[COLLECT] Saved {len(labs)} samples -> {fpath}")
346
+ print(f" Shape: {feats.shape} ({NUM_FEATURES} features)")
347
+
348
+ quality_report(labs)
349
+ else:
350
+ print("\n[COLLECT] No data collected")
351
+
352
+ print("[COLLECT] Done")
353
+
354
+
355
+ if __name__ == "__main__":
356
+ main()
models/eye_classifier.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ import numpy as np
6
+
7
+
8
+ class EyeClassifier(ABC):
9
+ @property
10
+ @abstractmethod
11
+ def name(self) -> str:
12
+ pass
13
+
14
+ @abstractmethod
15
+ def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
16
+ pass
17
+
18
+
19
+ class GeometricOnlyClassifier(EyeClassifier):
20
+ @property
21
+ def name(self) -> str:
22
+ return "geometric"
23
+
24
+ def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
25
+ return 1.0
26
+
27
+
28
+ class YOLOv11Classifier(EyeClassifier):
29
+ def __init__(self, checkpoint_path: str, device: str = "cpu"):
30
+ from ultralytics import YOLO
31
+
32
+ self._model = YOLO(checkpoint_path)
33
+ self._device = device
34
+
35
+ names = self._model.names
36
+ self._attentive_idx = None
37
+ for idx, cls_name in names.items():
38
+ if cls_name in ("open", "attentive"):
39
+ self._attentive_idx = idx
40
+ break
41
+ if self._attentive_idx is None:
42
+ self._attentive_idx = max(names.keys())
43
+ print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")
44
+
45
+ @property
46
+ def name(self) -> str:
47
+ return "yolo"
48
+
49
+ def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
50
+ if not crops_bgr:
51
+ return 1.0
52
+ results = self._model.predict(crops_bgr, device=self._device, verbose=False)
53
+ scores = [float(r.probs.data[self._attentive_idx]) for r in results]
54
+ return sum(scores) / len(scores) if scores else 1.0
55
+
56
+
57
+ def load_eye_classifier(
58
+ path: str | None = None,
59
+ backend: str = "yolo",
60
+ device: str = "cpu",
61
+ ) -> EyeClassifier:
62
+ if path is None or backend == "geometric":
63
+ return GeometricOnlyClassifier()
64
+
65
+ try:
66
+ return YOLOv11Classifier(path, device=device)
67
+ except ImportError:
68
+ print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
69
+ raise
models/eye_crop.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from models.face_mesh import FaceMeshDetector
5
+
6
+ LEFT_EYE_CONTOUR = FaceMeshDetector.LEFT_EYE_INDICES
7
+ RIGHT_EYE_CONTOUR = FaceMeshDetector.RIGHT_EYE_INDICES
8
+
9
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
10
+ IMAGENET_STD = (0.229, 0.224, 0.225)
11
+
12
+ CROP_SIZE = 96
13
+
14
+
15
+ def _bbox_from_landmarks(
16
+ landmarks: np.ndarray,
17
+ indices: list[int],
18
+ frame_w: int,
19
+ frame_h: int,
20
+ expand: float = 0.4,
21
+ ) -> tuple[int, int, int, int]:
22
+ pts = landmarks[indices, :2]
23
+ px = pts[:, 0] * frame_w
24
+ py = pts[:, 1] * frame_h
25
+
26
+ x_min, x_max = px.min(), px.max()
27
+ y_min, y_max = py.min(), py.max()
28
+ w = x_max - x_min
29
+ h = y_max - y_min
30
+ cx = (x_min + x_max) / 2
31
+ cy = (y_min + y_max) / 2
32
+
33
+ size = max(w, h) * (1 + expand)
34
+ half = size / 2
35
+
36
+ x1 = int(max(cx - half, 0))
37
+ y1 = int(max(cy - half, 0))
38
+ x2 = int(min(cx + half, frame_w))
39
+ y2 = int(min(cy + half, frame_h))
40
+
41
+ return x1, y1, x2, y2
42
+
43
+
44
+ def extract_eye_crops(
45
+ frame: np.ndarray,
46
+ landmarks: np.ndarray,
47
+ expand: float = 0.4,
48
+ crop_size: int = CROP_SIZE,
49
+ ) -> tuple[np.ndarray, np.ndarray, tuple, tuple]:
50
+ h, w = frame.shape[:2]
51
+
52
+ left_bbox = _bbox_from_landmarks(landmarks, LEFT_EYE_CONTOUR, w, h, expand)
53
+ right_bbox = _bbox_from_landmarks(landmarks, RIGHT_EYE_CONTOUR, w, h, expand)
54
+
55
+ left_crop = frame[left_bbox[1] : left_bbox[3], left_bbox[0] : left_bbox[2]]
56
+ right_crop = frame[right_bbox[1] : right_bbox[3], right_bbox[0] : right_bbox[2]]
57
+
58
+ if left_crop.size == 0:
59
+ left_crop = np.zeros((crop_size, crop_size, 3), dtype=np.uint8)
60
+ else:
61
+ left_crop = cv2.resize(left_crop, (crop_size, crop_size), interpolation=cv2.INTER_AREA)
62
+
63
+ if right_crop.size == 0:
64
+ right_crop = np.zeros((crop_size, crop_size, 3), dtype=np.uint8)
65
+ else:
66
+ right_crop = cv2.resize(right_crop, (crop_size, crop_size), interpolation=cv2.INTER_AREA)
67
+
68
+ return left_crop, right_crop, left_bbox, right_bbox
69
+
70
+
71
+ def crop_to_tensor(crop_bgr: np.ndarray):
72
+ import torch
73
+
74
+ rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
75
+ for c in range(3):
76
+ rgb[:, :, c] = (rgb[:, :, c] - IMAGENET_MEAN[c]) / IMAGENET_STD[c]
77
+ return torch.from_numpy(rgb.transpose(2, 0, 1))
models/eye_scorer.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+
5
+ _LEFT_EYE_EAR = [33, 160, 158, 133, 153, 145]
6
+ _RIGHT_EYE_EAR = [362, 385, 387, 263, 373, 380]
7
+
8
+ _LEFT_IRIS_CENTER = 468
9
+ _RIGHT_IRIS_CENTER = 473
10
+
11
+ _LEFT_EYE_INNER = 133
12
+ _LEFT_EYE_OUTER = 33
13
+ _RIGHT_EYE_INNER = 362
14
+ _RIGHT_EYE_OUTER = 263
15
+
16
+ _LEFT_EYE_TOP = 159
17
+ _LEFT_EYE_BOTTOM = 145
18
+ _RIGHT_EYE_TOP = 386
19
+ _RIGHT_EYE_BOTTOM = 374
20
+
21
+ _MOUTH_TOP = 13
22
+ _MOUTH_BOTTOM = 14
23
+ _MOUTH_LEFT = 78
24
+ _MOUTH_RIGHT = 308
25
+ _MOUTH_UPPER_1 = 82
26
+ _MOUTH_UPPER_2 = 312
27
+ _MOUTH_LOWER_1 = 87
28
+ _MOUTH_LOWER_2 = 317
29
+
30
+ MAR_YAWN_THRESHOLD = 0.55
31
+
32
+
33
+ def _distance(p1: np.ndarray, p2: np.ndarray) -> float:
34
+ return float(np.linalg.norm(p1 - p2))
35
+
36
+
37
+ def compute_ear(landmarks: np.ndarray, eye_indices: list[int]) -> float:
38
+ p1 = landmarks[eye_indices[0], :2]
39
+ p2 = landmarks[eye_indices[1], :2]
40
+ p3 = landmarks[eye_indices[2], :2]
41
+ p4 = landmarks[eye_indices[3], :2]
42
+ p5 = landmarks[eye_indices[4], :2]
43
+ p6 = landmarks[eye_indices[5], :2]
44
+
45
+ vertical1 = _distance(p2, p6)
46
+ vertical2 = _distance(p3, p5)
47
+ horizontal = _distance(p1, p4)
48
+
49
+ if horizontal < 1e-6:
50
+ return 0.0
51
+
52
+ return (vertical1 + vertical2) / (2.0 * horizontal)
53
+
54
+
55
+ def compute_avg_ear(landmarks: np.ndarray) -> float:
56
+ left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
57
+ right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
58
+ return (left_ear + right_ear) / 2.0
59
+
60
+
61
+ def compute_gaze_ratio(landmarks: np.ndarray) -> tuple[float, float]:
62
+ left_iris = landmarks[_LEFT_IRIS_CENTER, :2]
63
+ left_inner = landmarks[_LEFT_EYE_INNER, :2]
64
+ left_outer = landmarks[_LEFT_EYE_OUTER, :2]
65
+ left_top = landmarks[_LEFT_EYE_TOP, :2]
66
+ left_bottom = landmarks[_LEFT_EYE_BOTTOM, :2]
67
+
68
+ right_iris = landmarks[_RIGHT_IRIS_CENTER, :2]
69
+ right_inner = landmarks[_RIGHT_EYE_INNER, :2]
70
+ right_outer = landmarks[_RIGHT_EYE_OUTER, :2]
71
+ right_top = landmarks[_RIGHT_EYE_TOP, :2]
72
+ right_bottom = landmarks[_RIGHT_EYE_BOTTOM, :2]
73
+
74
+ left_h_total = _distance(left_inner, left_outer)
75
+ right_h_total = _distance(right_inner, right_outer)
76
+
77
+ if left_h_total < 1e-6 or right_h_total < 1e-6:
78
+ return 0.5, 0.5
79
+
80
+ left_h_ratio = _distance(left_outer, left_iris) / left_h_total
81
+ right_h_ratio = _distance(right_outer, right_iris) / right_h_total
82
+ h_ratio = (left_h_ratio + right_h_ratio) / 2.0
83
+
84
+ left_v_total = _distance(left_top, left_bottom)
85
+ right_v_total = _distance(right_top, right_bottom)
86
+
87
+ if left_v_total < 1e-6 or right_v_total < 1e-6:
88
+ return h_ratio, 0.5
89
+
90
+ left_v_ratio = _distance(left_top, left_iris) / left_v_total
91
+ right_v_ratio = _distance(right_top, right_iris) / right_v_total
92
+ v_ratio = (left_v_ratio + right_v_ratio) / 2.0
93
+
94
+ return float(np.clip(h_ratio, 0, 1)), float(np.clip(v_ratio, 0, 1))
95
+
96
+
97
+ def compute_mar(landmarks: np.ndarray) -> float:
98
+ top = landmarks[_MOUTH_TOP, :2]
99
+ bottom = landmarks[_MOUTH_BOTTOM, :2]
100
+ left = landmarks[_MOUTH_LEFT, :2]
101
+ right = landmarks[_MOUTH_RIGHT, :2]
102
+ upper1 = landmarks[_MOUTH_UPPER_1, :2]
103
+ lower1 = landmarks[_MOUTH_LOWER_1, :2]
104
+ upper2 = landmarks[_MOUTH_UPPER_2, :2]
105
+ lower2 = landmarks[_MOUTH_LOWER_2, :2]
106
+
107
+ horizontal = _distance(left, right)
108
+ if horizontal < 1e-6:
109
+ return 0.0
110
+ v1 = _distance(upper1, lower1)
111
+ v2 = _distance(top, bottom)
112
+ v3 = _distance(upper2, lower2)
113
+ return (v1 + v2 + v3) / (2.0 * horizontal)
114
+
115
+
116
+ class EyeBehaviourScorer:
117
+ def __init__(
118
+ self,
119
+ ear_open: float = 0.30,
120
+ ear_closed: float = 0.16,
121
+ gaze_max_offset: float = 0.28,
122
+ ):
123
+ self.ear_open = ear_open
124
+ self.ear_closed = ear_closed
125
+ self.gaze_max_offset = gaze_max_offset
126
+
127
+ def _ear_score(self, ear: float) -> float:
128
+ if ear >= self.ear_open:
129
+ return 1.0
130
+ if ear <= self.ear_closed:
131
+ return 0.0
132
+ return (ear - self.ear_closed) / (self.ear_open - self.ear_closed)
133
+
134
+ def _gaze_score(self, h_ratio: float, v_ratio: float) -> float:
135
+ h_offset = abs(h_ratio - 0.5)
136
+ v_offset = abs(v_ratio - 0.5)
137
+ offset = math.sqrt(h_offset**2 + v_offset**2)
138
+ t = min(offset / self.gaze_max_offset, 1.0)
139
+ return 0.5 * (1.0 + math.cos(math.pi * t))
140
+
141
+ def score(self, landmarks: np.ndarray) -> float:
142
+ left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
143
+ right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
144
+ # Use minimum EAR so closing ONE eye is enough to drop the score
145
+ ear = min(left_ear, right_ear)
146
+ ear_s = self._ear_score(ear)
147
+ if ear_s < 0.3:
148
+ return ear_s
149
+ h_ratio, v_ratio = compute_gaze_ratio(landmarks)
150
+ gaze_s = self._gaze_score(h_ratio, v_ratio)
151
+ return ear_s * gaze_s
152
+
153
+ def detailed_score(self, landmarks: np.ndarray) -> dict:
154
+ left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
155
+ right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
156
+ ear = min(left_ear, right_ear)
157
+ ear_s = self._ear_score(ear)
158
+ h_ratio, v_ratio = compute_gaze_ratio(landmarks)
159
+ gaze_s = self._gaze_score(h_ratio, v_ratio)
160
+ s_eye = ear_s if ear_s < 0.3 else ear_s * gaze_s
161
+ return {
162
+ "ear": round(ear, 4),
163
+ "ear_score": round(ear_s, 4),
164
+ "h_gaze": round(h_ratio, 4),
165
+ "v_gaze": round(v_ratio, 4),
166
+ "gaze_score": round(gaze_s, 4),
167
+ "s_eye": round(s_eye, 4),
168
+ }
models/face_mesh.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from pathlib import Path
4
+ from urllib.request import urlretrieve
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import mediapipe as mp
9
+ from mediapipe.tasks.python.vision import FaceLandmarkerOptions, FaceLandmarker, RunningMode
10
+ from mediapipe.tasks import python as mp_tasks
11
+
12
+ _MODEL_URL = (
13
+ "https://storage.googleapis.com/mediapipe-models/face_landmarker/"
14
+ "face_landmarker/float16/latest/face_landmarker.task"
15
+ )
16
+
17
+
18
+ def _ensure_model() -> str:
19
+ cache_dir = Path(os.environ.get(
20
+ "FOCUSGUARD_CACHE_DIR",
21
+ Path.home() / ".cache" / "focusguard",
22
+ ))
23
+ model_path = cache_dir / "face_landmarker.task"
24
+ if model_path.exists():
25
+ return str(model_path)
26
+ cache_dir.mkdir(parents=True, exist_ok=True)
27
+ print(f"[FACE_MESH] Downloading model to {model_path}...")
28
+ urlretrieve(_MODEL_URL, model_path)
29
+ print("[FACE_MESH] Download complete.")
30
+ return str(model_path)
31
+
32
+
33
+ class FaceMeshDetector:
34
+ LEFT_EYE_INDICES = [33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246]
35
+ RIGHT_EYE_INDICES = [362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398]
36
+ LEFT_IRIS_INDICES = [468, 469, 470, 471, 472]
37
+ RIGHT_IRIS_INDICES = [473, 474, 475, 476, 477]
38
+
39
+ def __init__(
40
+ self,
41
+ max_num_faces: int = 1,
42
+ min_detection_confidence: float = 0.5,
43
+ min_tracking_confidence: float = 0.5,
44
+ ):
45
+ model_path = _ensure_model()
46
+ options = FaceLandmarkerOptions(
47
+ base_options=mp_tasks.BaseOptions(model_asset_path=model_path),
48
+ num_faces=max_num_faces,
49
+ min_face_detection_confidence=min_detection_confidence,
50
+ min_face_presence_confidence=min_detection_confidence,
51
+ min_tracking_confidence=min_tracking_confidence,
52
+ running_mode=RunningMode.VIDEO,
53
+ )
54
+ self._landmarker = FaceLandmarker.create_from_options(options)
55
+ self._t0 = time.monotonic()
56
+ self._last_ts = 0
57
+
58
+ def process(self, bgr_frame: np.ndarray) -> np.ndarray | None:
59
+ # BGR in -> (478,3) norm x,y,z or None
60
+ rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
61
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
62
+ ts = max(int((time.monotonic() - self._t0) * 1000), self._last_ts + 1)
63
+ self._last_ts = ts
64
+ result = self._landmarker.detect_for_video(mp_image, ts)
65
+
66
+ if not result.face_landmarks:
67
+ return None
68
+
69
+ face = result.face_landmarks[0]
70
+ return np.array([(lm.x, lm.y, lm.z) for lm in face], dtype=np.float32)
71
+
72
+ def get_pixel_landmarks(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> np.ndarray:
73
+ # norm -> pixel (x,y)
74
+ pixel = np.zeros((landmarks.shape[0], 2), dtype=np.int32)
75
+ pixel[:, 0] = (landmarks[:, 0] * frame_w).astype(np.int32)
76
+ pixel[:, 1] = (landmarks[:, 1] * frame_h).astype(np.int32)
77
+ return pixel
78
+
79
+ def get_3d_landmarks(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> np.ndarray:
80
+ # norm -> pixel-scale x,y,z (z scaled by width)
81
+ pts = np.zeros_like(landmarks)
82
+ pts[:, 0] = landmarks[:, 0] * frame_w
83
+ pts[:, 1] = landmarks[:, 1] * frame_h
84
+ pts[:, 2] = landmarks[:, 2] * frame_w
85
+ return pts
86
+
87
+ def close(self):
88
+ self._landmarker.close()
89
+
90
+ def __enter__(self):
91
+ return self
92
+
93
+ def __exit__(self, *args):
94
+ self.close()
models/head_pose.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ _LANDMARK_INDICES = [1, 152, 33, 263, 61, 291]
7
+
8
+ _MODEL_POINTS = np.array(
9
+ [
10
+ [0.0, 0.0, 0.0],
11
+ [0.0, -330.0, -65.0],
12
+ [-225.0, 170.0, -135.0],
13
+ [225.0, 170.0, -135.0],
14
+ [-150.0, -150.0, -125.0],
15
+ [150.0, -150.0, -125.0],
16
+ ],
17
+ dtype=np.float64,
18
+ )
19
+
20
+
21
+ class HeadPoseEstimator:
22
+ def __init__(self, max_angle: float = 30.0, roll_weight: float = 0.5):
23
+ self.max_angle = max_angle
24
+ self.roll_weight = roll_weight
25
+ self._camera_matrix = None
26
+ self._frame_size = None
27
+ self._dist_coeffs = np.zeros((4, 1), dtype=np.float64)
28
+ self._cache_key = None
29
+ self._cache_result = None
30
+
31
+ def _get_camera_matrix(self, frame_w: int, frame_h: int) -> np.ndarray:
32
+ if self._camera_matrix is not None and self._frame_size == (frame_w, frame_h):
33
+ return self._camera_matrix
34
+ focal_length = float(frame_w)
35
+ cx, cy = frame_w / 2.0, frame_h / 2.0
36
+ self._camera_matrix = np.array(
37
+ [[focal_length, 0, cx], [0, focal_length, cy], [0, 0, 1]],
38
+ dtype=np.float64,
39
+ )
40
+ self._frame_size = (frame_w, frame_h)
41
+ return self._camera_matrix
42
+
43
+ def _solve(self, landmarks: np.ndarray, frame_w: int, frame_h: int):
44
+ key = (landmarks.data.tobytes(), frame_w, frame_h)
45
+ if self._cache_key == key:
46
+ return self._cache_result
47
+
48
+ image_points = np.array(
49
+ [
50
+ [landmarks[i, 0] * frame_w, landmarks[i, 1] * frame_h]
51
+ for i in _LANDMARK_INDICES
52
+ ],
53
+ dtype=np.float64,
54
+ )
55
+ camera_matrix = self._get_camera_matrix(frame_w, frame_h)
56
+ success, rvec, tvec = cv2.solvePnP(
57
+ _MODEL_POINTS,
58
+ image_points,
59
+ camera_matrix,
60
+ self._dist_coeffs,
61
+ flags=cv2.SOLVEPNP_ITERATIVE,
62
+ )
63
+ result = (success, rvec, tvec, image_points)
64
+ self._cache_key = key
65
+ self._cache_result = result
66
+ return result
67
+
68
+ def estimate(
69
+ self, landmarks: np.ndarray, frame_w: int, frame_h: int
70
+ ) -> tuple[float, float, float] | None:
71
+ success, rvec, tvec, _ = self._solve(landmarks, frame_w, frame_h)
72
+ if not success:
73
+ return None
74
+
75
+ rmat, _ = cv2.Rodrigues(rvec)
76
+ nose_dir = rmat @ np.array([0.0, 0.0, 1.0])
77
+ face_up = rmat @ np.array([0.0, 1.0, 0.0])
78
+
79
+ yaw = math.degrees(math.atan2(nose_dir[0], -nose_dir[2]))
80
+ pitch = math.degrees(math.asin(np.clip(-nose_dir[1], -1.0, 1.0)))
81
+ roll = math.degrees(math.atan2(face_up[0], -face_up[1]))
82
+
83
+ return (yaw, pitch, roll)
84
+
85
+ def score(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> float:
86
+ angles = self.estimate(landmarks, frame_w, frame_h)
87
+ if angles is None:
88
+ return 0.0
89
+
90
+ yaw, pitch, roll = angles
91
+ deviation = math.sqrt(yaw**2 + pitch**2 + (self.roll_weight * roll) ** 2)
92
+ t = min(deviation / self.max_angle, 1.0)
93
+ return 0.5 * (1.0 + math.cos(math.pi * t))
94
+
95
+ def draw_axes(
96
+ self,
97
+ frame: np.ndarray,
98
+ landmarks: np.ndarray,
99
+ axis_length: float = 50.0,
100
+ ) -> np.ndarray:
101
+ h, w = frame.shape[:2]
102
+ success, rvec, tvec, image_points = self._solve(landmarks, w, h)
103
+ if not success:
104
+ return frame
105
+
106
+ camera_matrix = self._get_camera_matrix(w, h)
107
+ nose = tuple(image_points[0].astype(int))
108
+
109
+ axes_3d = np.float64(
110
+ [[axis_length, 0, 0], [0, axis_length, 0], [0, 0, axis_length]]
111
+ )
112
+ projected, _ = cv2.projectPoints(
113
+ axes_3d, rvec, tvec, camera_matrix, self._dist_coeffs
114
+ )
115
+
116
+ colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]
117
+ for i, color in enumerate(colors):
118
+ pt = tuple(projected[i].ravel().astype(int))
119
+ cv2.line(frame, nose, pt, color, 2)
120
+
121
+ return frame
package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
package.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "my-ai-app",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "react": "^19.2.0",
14
+ "react-dom": "^19.2.0"
15
+ },
16
+ "devDependencies": {
17
+ "@eslint/js": "^9.39.1",
18
+ "@types/react": "^19.2.5",
19
+ "@types/react-dom": "^19.2.3",
20
+ "@vitejs/plugin-react": "^5.1.1",
21
+ "eslint": "^9.39.1",
22
+ "eslint-plugin-react-hooks": "^7.0.1",
23
+ "eslint-plugin-react-refresh": "^0.4.24",
24
+ "globals": "^16.5.0",
25
+ "vite": "^7.2.4"
26
+ }
27
+ }
public/assets/111.jpg ADDED
public/vite.svg ADDED
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ websockets
2
+ python-multipart
3
+ jinja2
4
+ aiortc
5
+ av
6
+ mediapipe>=0.10.14
7
+ opencv-python-headless>=4.8.0
8
+ numpy>=1.24.0
9
+ scikit-learn>=1.2.0
10
+ joblib>=1.2.0
11
+ fastapi>=0.104.0
12
+ uvicorn[standard]>=0.24.0
13
+ aiosqlite>=0.19.0
14
+ pydantic>=2.0.0
src/App.css ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* =========================================
2
+ 1. REACT layout setting
3
+ ========================================= */
4
+ html, body, #root {
5
+ width: 100%;
6
+ height: 100%;
7
+ margin: 0;
8
+ padding: 0;
9
+ }
10
+
11
+ .app-container {
12
+ width: 100%;
13
+ min-height: 100vh; /* screen height */
14
+ display: flex;
15
+ flex-direction: column;
16
+ background-color: #f9f9f9;
17
+ }
18
+
19
+ /* =========================================
20
+ 2. original layout
21
+ ========================================= */
22
+
23
+ /* GLOBAL STYLES */
24
+ body {
25
+ font-family: 'Nunito', sans-serif;
26
+ background-color: #f9f9f9;
27
+ overflow-x: hidden;
28
+ overflow-y: auto;
29
+ }
30
+
31
+ /* dynamic class name */
32
+ .hidden {
33
+ display: none !important;
34
+ }
35
+
36
+ /* TOP MENU */
37
+ #top-menu {
38
+ height: 60px;
39
+ background-color: white;
40
+ display: flex;
41
+ align-items: center;
42
+ justify-content: center; /* Center buttons horizontally */
43
+ box-shadow: 0 2px 5px rgba(0,0,0,0.05);
44
+ position: fixed;
45
+ top: 0;
46
+ width: 100%;
47
+ z-index: 1000;
48
+ }
49
+
50
+ .menu-btn {
51
+ background: none;
52
+ border: none;
53
+ font-family: 'Nunito', sans-serif;
54
+ font-size: 16px;
55
+ color: #333;
56
+ padding: 10px 20px;
57
+ cursor: pointer;
58
+ transition: background-color 0.2s;
59
+ }
60
+
61
+ .menu-btn:hover {
62
+ background-color: #f0f0f0;
63
+ border-radius: 4px;
64
+ }
65
+
66
+ /* active for React */
67
+ .menu-btn.active {
68
+ font-weight: bold;
69
+ color: #007BFF;
70
+ background-color: #eef7ff;
71
+ border-radius: 4px;
72
+ }
73
+
74
+ .separator {
75
+ width: 1px;
76
+ height: 20px;
77
+ background-color: #555; /* Dark gray separator */
78
+ margin: 0 5px;
79
+ }
80
+
81
+ /* PAGE CONTAINER */
82
+ .page {
83
+ /* content under menu */
84
+ min-height: calc(100vh - 60px);
85
+ width: 100%;
86
+ padding-top: 60px; /* Space for fixed menu */
87
+ padding-bottom: 40px; /* Space at bottom for scrolling */
88
+ box-sizing: border-box;
89
+ display: flex;
90
+ flex-direction: column;
91
+ align-items: center;
92
+ overflow-y: auto;
93
+ }
94
+
95
+ /* Ensure page titles are black */
96
+ .page h1 {
97
+ color: #000 !important;
98
+ background: transparent !important;
99
+ }
100
+
101
+ .page-title {
102
+ color: #000 !important;
103
+ background: transparent !important;
104
+ }
105
+
106
+ /* PAGE A SPECIFIC */
107
+ #page-a {
108
+ justify-content: center; /* Center vertically */
109
+ /* 注意:因为 React 结构变化,如果感觉偏下,可以微调这个 margin-top */
110
+ margin-top: -40px;
111
+ flex: 1; /* 确保它占满剩余空间以便垂直居中 */
112
+ }
113
+
114
+ #page-a h1 {
115
+ font-size: 80px;
116
+ margin: 0 0 10px 0;
117
+ color: #000;
118
+ text-align: center; /* 确保文字居中 */
119
+ }
120
+
121
+ #page-a p {
122
+ color: #666;
123
+ font-size: 20px;
124
+ margin-bottom: 40px;
125
+ text-align: center;
126
+ }
127
+
128
+ .btn-main {
129
+ background-color: #007BFF; /* Blue */
130
+ color: white;
131
+ border: none;
132
+ padding: 15px 50px;
133
+ font-size: 20px;
134
+ font-family: 'Nunito', sans-serif;
135
+ border-radius: 30px; /* Fully rounded corners */
136
+ cursor: pointer;
137
+ transition: transform 0.2s ease;
138
+ }
139
+
140
+ .btn-main:hover {
141
+ transform: scale(1.1); /* Zoom effect */
142
+ }
143
+
144
+ /* PAGE B SPECIFIC */
145
+ #page-b {
146
+ justify-content: space-evenly; /* Distribute vertical space */
147
+ padding-bottom: 20px;
148
+ min-height: calc(100vh - 60px); /* 再次确保高度足够 */
149
+ }
150
+
151
+ /* 1. Display Area */
152
+ #display-area {
153
+ width: 60%;
154
+ height: 50vh; /* 改用 vh 单位,确保在不同屏幕下的高度比例 */
155
+ min-height: 300px;
156
+ border: 2px solid #ddd;
157
+ border-radius: 12px;
158
+ background-color: #fff;
159
+ display: flex;
160
+ align-items: center;
161
+ justify-content: center;
162
+ color: #555;
163
+ font-size: 24px;
164
+ position: relative;
165
+ /* 确保视频元素也能居中且不溢出 */
166
+ overflow: hidden;
167
+ }
168
+
169
+ #display-area video {
170
+ width: 100%;
171
+ height: 100%;
172
+ object-fit: cover; /* 类似于 background-size: cover */
173
+ }
174
+
175
+ /* 2. Timeline Area */
176
+ #timeline-area {
177
+ width: 60%;
178
+ height: 80px;
179
+ position: relative;
180
+ display: flex;
181
+ flex-direction: column;
182
+ justify-content: flex-end;
183
+ }
184
+
185
+ .timeline-label {
186
+ position: absolute;
187
+ top: 0;
188
+ left: 0;
189
+ color: #888;
190
+ font-size: 14px;
191
+ }
192
+
193
+ #timeline-line {
194
+ width: 100%;
195
+ height: 2px;
196
+ background-color: #87CEEB; /* Light blue */
197
+ }
198
+
199
+ /* 3. Control Panel */
200
+ #control-panel {
201
+ display: flex;
202
+ gap: 20px;
203
+ width: 60%;
204
+ justify-content: space-between;
205
+ }
206
+
207
+ .action-btn {
208
+ flex: 1; /* Evenly distributed width */
209
+ padding: 12px 0;
210
+ border: none;
211
+ border-radius: 12px;
212
+ font-size: 16px;
213
+ font-family: 'Nunito', sans-serif;
214
+ font-weight: 700;
215
+ cursor: pointer;
216
+ color: white;
217
+ transition: opacity 0.2s;
218
+ }
219
+
220
+ .action-btn:hover {
221
+ opacity: 0.9;
222
+ }
223
+
224
+ .action-btn.green { background-color: #28a745; }
225
+ .action-btn.yellow { background-color: #ffce0b; }
226
+ .action-btn.blue { background-color: #326ed6; }
227
+ .action-btn.red { background-color: #dc3545; }
228
+
229
+ /* 4. Frame Control */
230
+ #frame-control {
231
+ display: flex;
232
+ align-items: center;
233
+ gap: 15px;
234
+ color: #333;
235
+ font-weight: bold;
236
+ }
237
+
238
+ #frame-slider {
239
+ width: 200px;
240
+ cursor: pointer;
241
+ }
242
+
243
+ #frame-input {
244
+ width: 50px;
245
+ padding: 5px;
246
+ border: 1px solid #ccc;
247
+ border-radius: 5px;
248
+ text-align: center;
249
+ font-family: 'Nunito', sans-serif;
250
+ }
251
+
252
+ /* ================ ACHIEVEMENT PAGE ================ */
253
+
254
+ .stats-grid {
255
+ display: grid;
256
+ grid-template-columns: repeat(4, 1fr);
257
+ gap: 20px;
258
+ width: 80%;
259
+ margin: 40px auto;
260
+ }
261
+
262
+ .stat-card {
263
+ background: white;
264
+ padding: 30px;
265
+ border-radius: 12px;
266
+ text-align: center;
267
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
268
+ }
269
+
270
+ .stat-number {
271
+ font-size: 48px;
272
+ font-weight: bold;
273
+ color: #007BFF;
274
+ margin-bottom: 10px;
275
+ }
276
+
277
+ .stat-label {
278
+ font-size: 16px;
279
+ color: #666;
280
+ }
281
+
282
+ .achievements-section {
283
+ width: 80%;
284
+ margin: 0 auto;
285
+ }
286
+
287
+ .achievements-section h2 {
288
+ color: #333;
289
+ margin-bottom: 20px;
290
+ }
291
+
292
+ .badges-grid {
293
+ display: grid;
294
+ grid-template-columns: repeat(3, 1fr);
295
+ gap: 20px;
296
+ }
297
+
298
+ .badge {
299
+ background: white;
300
+ padding: 30px 20px;
301
+ border-radius: 12px;
302
+ text-align: center;
303
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
304
+ transition: transform 0.2s;
305
+ }
306
+
307
+ .badge:hover {
308
+ transform: translateY(-5px);
309
+ }
310
+
311
+ .badge.locked {
312
+ opacity: 0.4;
313
+ filter: grayscale(100%);
314
+ }
315
+
316
+ .badge-icon {
317
+ font-size: 64px;
318
+ margin-bottom: 15px;
319
+ }
320
+
321
+ .badge-name {
322
+ font-size: 16px;
323
+ font-weight: bold;
324
+ color: #333;
325
+ }
326
+
327
+ /* ================ RECORDS PAGE ================ */
328
+
329
+ .records-controls {
330
+ display: flex;
331
+ gap: 10px;
332
+ margin: 20px auto;
333
+ width: 80%;
334
+ justify-content: center;
335
+ }
336
+
337
+ .filter-btn {
338
+ padding: 10px 20px;
339
+ border: 2px solid #007BFF;
340
+ background: white;
341
+ color: #007BFF;
342
+ border-radius: 8px;
343
+ cursor: pointer;
344
+ font-family: 'Nunito', sans-serif;
345
+ font-weight: 600;
346
+ transition: all 0.2s;
347
+ }
348
+
349
+ .filter-btn:hover {
350
+ background: #e7f3ff;
351
+ }
352
+
353
+ .filter-btn.active {
354
+ background: #007BFF;
355
+ color: white;
356
+ }
357
+
358
+ .chart-container {
359
+ width: 80%;
360
+ background: white;
361
+ padding: 30px;
362
+ border-radius: 12px;
363
+ margin: 20px auto;
364
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
365
+ }
366
+
367
+ #focus-chart {
368
+ display: block;
369
+ margin: 0 auto;
370
+ /* 确保图表在容器内自适应 */
371
+ max-width: 100%;
372
+ }
373
+
374
+ .sessions-list {
375
+ width: 80%;
376
+ margin: 20px auto;
377
+ }
378
+
379
+ .sessions-list h2 {
380
+ color: #333;
381
+ margin-bottom: 15px;
382
+ }
383
+
384
+ #sessions-table {
385
+ width: 100%;
386
+ background: white;
387
+ border-collapse: collapse;
388
+ border-radius: 12px;
389
+ overflow: hidden;
390
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
391
+ }
392
+
393
+ #sessions-table th {
394
+ background: #007BFF;
395
+ color: white;
396
+ padding: 15px;
397
+ text-align: left;
398
+ font-weight: 600;
399
+ }
400
+
401
+ #sessions-table td {
402
+ padding: 12px 15px;
403
+ border-bottom: 1px solid #eee;
404
+ }
405
+
406
+ #sessions-table tr:last-child td {
407
+ border-bottom: none;
408
+ }
409
+
410
+ #sessions-table tbody tr:hover {
411
+ background: #f8f9fa;
412
+ }
413
+
414
+ .btn-view {
415
+ padding: 6px 12px;
416
+ background: #007BFF;
417
+ color: white;
418
+ border: none;
419
+ border-radius: 5px;
420
+ cursor: pointer;
421
+ font-family: 'Nunito', sans-serif;
422
+ transition: background 0.2s;
423
+ }
424
+
425
+ .btn-view:hover {
426
+ background: #0056b3;
427
+ }
428
+
429
+ /* ================ SETTINGS PAGE ================ */
430
+
431
+ .settings-container {
432
+ width: 60%;
433
+ max-width: 800px;
434
+ margin: 20px auto;
435
+ }
436
+
437
+ .setting-group {
438
+ background: white;
439
+ padding: 30px;
440
+ border-radius: 12px;
441
+ margin-bottom: 20px;
442
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
443
+ }
444
+
445
+ .setting-group h2 {
446
+ margin-top: 0;
447
+ color: #333;
448
+ font-size: 20px;
449
+ margin-bottom: 20px;
450
+ border-bottom: 2px solid #007BFF;
451
+ padding-bottom: 10px;
452
+ }
453
+
454
+ .setting-item {
455
+ margin-bottom: 25px;
456
+ }
457
+
458
+ .setting-item:last-child {
459
+ margin-bottom: 0;
460
+ }
461
+
462
+ .setting-item label {
463
+ display: block;
464
+ margin-bottom: 8px;
465
+ color: #333;
466
+ font-weight: 600;
467
+ }
468
+
469
+ .slider-group {
470
+ display: flex;
471
+ align-items: center;
472
+ gap: 15px;
473
+ }
474
+
475
+ .slider-group input[type="range"] {
476
+ flex: 1;
477
+ }
478
+
479
+ .slider-group span {
480
+ min-width: 40px;
481
+ text-align: center;
482
+ font-weight: bold;
483
+ color: #007BFF;
484
+ font-size: 18px;
485
+ }
486
+
487
+ .setting-description {
488
+ font-size: 14px;
489
+ color: #666;
490
+ margin-top: 5px;
491
+ font-style: italic;
492
+ }
493
+
494
+ input[type="checkbox"] {
495
+ margin-right: 10px;
496
+ cursor: pointer;
497
+ }
498
+
499
+ input[type="number"] {
500
+ width: 100px;
501
+ padding: 8px;
502
+ border: 1px solid #ccc;
503
+ border-radius: 5px;
504
+ font-family: 'Nunito', sans-serif;
505
+ }
506
+
507
+ /* --- 新的代码:让按钮居中且变宽 --- */
508
+ .setting-group .action-btn {
509
+ display: inline-block; /* 允许并排显示 */
510
+ width: 48%; /* 两个按钮各占约一半宽度 (留2%缝隙) */
511
+ margin: 15px 1%; /* 上下 15px,左右 1% 间距来实现居中和分隔 */
512
+ text-align: center; /* 文字居中 */
513
+ box-sizing: border-box; /* ��保边框不会撑大按钮导致换行 */
514
+ }
515
+
516
+ #save-settings {
517
+ display: block;
518
+ margin: 20px auto;
519
+ }
520
+
521
+ /* ================ HELP PAGE ================ */
522
+
523
+ .help-container {
524
+ width: 70%;
525
+ max-width: 900px;
526
+ margin: 20px auto;
527
+ }
528
+
529
+ /* Fake ad block (Help page) */
530
+ .fake-ad {
531
+ position: relative;
532
+ display: block;
533
+ width: min(600px, 90%);
534
+ margin: 10px auto 30px auto;
535
+ border: 1px solid #e5e5e5;
536
+ border-radius: 12px;
537
+ overflow: hidden;
538
+ background: #fff;
539
+ text-decoration: none;
540
+ box-shadow: 0 8px 24px rgba(0,0,0,0.12);
541
+ transition: transform 0.2s ease, box-shadow 0.2s ease;
542
+ }
543
+
544
+ .fake-ad:hover {
545
+ transform: translateY(-2px);
546
+ box-shadow: 0 12px 30px rgba(0,0,0,0.16);
547
+ }
548
+
549
+ .fake-ad img {
550
+ display: block;
551
+ width: 100%;
552
+ height: auto;
553
+ }
554
+
555
+ .fake-ad-badge {
556
+ position: absolute;
557
+ top: 12px;
558
+ left: 12px;
559
+ background: rgba(0,0,0,0.75);
560
+ color: #fff;
561
+ font-size: 12px;
562
+ padding: 4px 8px;
563
+ border-radius: 6px;
564
+ letter-spacing: 0.5px;
565
+ }
566
+
567
+ .fake-ad-cta {
568
+ position: absolute;
569
+ right: 12px;
570
+ bottom: 12px;
571
+ background: #111;
572
+ color: #fff;
573
+ font-size: 14px;
574
+ padding: 8px 12px;
575
+ border-radius: 8px;
576
+ }
577
+
578
+ .help-section {
579
+ background: white;
580
+ padding: 30px;
581
+ border-radius: 12px;
582
+ margin-bottom: 20px;
583
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
584
+ }
585
+
586
+ .help-section h2 {
587
+ color: #007BFF;
588
+ margin-top: 0;
589
+ margin-bottom: 15px;
590
+ }
591
+
592
+ .help-section ol,
593
+ .help-section ul {
594
+ line-height: 1.8;
595
+ color: #333;
596
+ }
597
+
598
+ .help-section p {
599
+ line-height: 1.6;
600
+ color: #333;
601
+ }
602
+
603
+ details {
604
+ margin: 15px 0;
605
+ cursor: pointer;
606
+ padding: 10px;
607
+ background: #f8f9fa;
608
+ border-radius: 5px;
609
+ }
610
+
611
+ summary {
612
+ font-weight: bold;
613
+ padding: 5px;
614
+ color: #007BFF;
615
+ }
616
+
617
+ details[open] summary {
618
+ margin-bottom: 10px;
619
+ border-bottom: 1px solid #ddd;
620
+ padding-bottom: 10px;
621
+ }
622
+
623
+ details p {
624
+ margin: 10px 0 0 0;
625
+ }
626
+
627
+ /* ================ SESSION SUMMARY MODAL ================ */
628
+ /* 如果将来要做弹窗,这些样式可以直接复用 */
629
+ .modal-overlay {
630
+ position: fixed;
631
+ top: 0;
632
+ left: 0;
633
+ width: 100%;
634
+ height: 100%;
635
+ background: rgba(0, 0, 0, 0.7);
636
+ display: flex;
637
+ align-items: center;
638
+ justify-content: center;
639
+ z-index: 2000;
640
+ }
641
+
642
+ .modal-content {
643
+ background: white;
644
+ padding: 40px;
645
+ border-radius: 16px;
646
+ box-shadow: 0 10px 40px rgba(0,0,0,0.3);
647
+ max-width: 500px;
648
+ width: 90%;
649
+ }
650
+
651
+ .modal-content h2 {
652
+ margin-top: 0;
653
+ color: #333;
654
+ text-align: center;
655
+ margin-bottom: 30px;
656
+ }
657
+
658
+ .summary-stats {
659
+ margin-bottom: 30px;
660
+ }
661
+
662
+ .summary-item {
663
+ display: flex;
664
+ justify-content: space-between;
665
+ padding: 15px 0;
666
+ border-bottom: 1px solid #eee;
667
+ }
668
+
669
+ .summary-item:last-child {
670
+ border-bottom: none;
671
+ }
672
+
673
+ .summary-label {
674
+ font-weight: 600;
675
+ color: #666;
676
+ }
677
+
678
+ .summary-value {
679
+ font-weight: bold;
680
+ color: #007BFF;
681
+ font-size: 18px;
682
+ }
683
+
684
+ .modal-content .btn-main {
685
+ display: block;
686
+ margin: 0 auto;
687
+ padding: 12px 40px;
688
+ }
689
+
690
+ /* ================ TIMELINE BLOCKS ================ */
691
+
692
+ .timeline-block {
693
+ transition: opacity 0.2s;
694
+ border-radius: 2px;
695
+ }
696
+
697
+ .timeline-block:hover {
698
+ opacity: 0.7;
699
+ }
700
+
701
+ /* ================ RESPONSIVE DESIGN ================ */
702
+
703
+ @media (max-width: 1200px) {
704
+ .stats-grid {
705
+ grid-template-columns: repeat(2, 1fr);
706
+ }
707
+
708
+ .badges-grid {
709
+ grid-template-columns: repeat(2, 1fr);
710
+ }
711
+ }
712
+
713
+ @media (max-width: 768px) {
714
+ .stats-grid,
715
+ .badges-grid {
716
+ grid-template-columns: 1fr;
717
+ width: 90%;
718
+ }
719
+
720
+ .settings-container,
721
+ .help-container,
722
+ .chart-container,
723
+ .sessions-list,
724
+ .records-controls {
725
+ width: 90%;
726
+ }
727
+
728
+ #control-panel {
729
+ width: 90%;
730
+ flex-wrap: wrap;
731
+ }
732
+
733
+ #display-area {
734
+ width: 90%;
735
+ }
736
+
737
+ #timeline-area {
738
+ width: 90%;
739
+ }
740
+
741
+ #frame-control {
742
+ width: 90%;
743
+ flex-direction: column;
744
+ }
745
+ }
746
+ /* =========================================
747
+ SESSION RESULT OVERLAY (新增)
748
+ ========================================= */
749
+
750
+ .session-result-overlay {
751
+ position: absolute;
752
+ top: 0;
753
+ left: 0;
754
+ width: 100%;
755
+ height: 100%;
756
+ background-color: rgba(0, 0, 0, 0.85); /* 深色半透明背景 */
757
+ display: flex;
758
+ flex-direction: column;
759
+ justify-content: center;
760
+ align-items: center;
761
+ color: white;
762
+ z-index: 10;
763
+ animation: fadeIn 0.5s ease;
764
+ backdrop-filter: blur(5px); /* 背景模糊效果 (可选) */
765
+ }
766
+
767
+ .session-result-overlay h3 {
768
+ font-size: 32px;
769
+ margin-bottom: 30px;
770
+ color: #4cd137; /* 绿色标题 */
771
+ text-transform: uppercase;
772
+ letter-spacing: 2px;
773
+ }
774
+
775
+ .session-result-overlay .result-item {
776
+ display: flex;
777
+ justify-content: space-between;
778
+ width: 200px; /* 控制宽度 */
779
+ margin-bottom: 15px;
780
+ font-size: 20px;
781
+ border-bottom: 1px solid rgba(255,255,255,0.2);
782
+ padding-bottom: 5px;
783
+ }
784
+
785
+ .session-result-overlay .label {
786
+ color: #ccc;
787
+ font-weight: normal;
788
+ }
789
+
790
+ .session-result-overlay .value {
791
+ color: #fff;
792
+ font-weight: bold;
793
+ font-family: 'Courier New', monospace; /* 看起来像数据 */
794
+ }
795
+
796
+ @keyframes fadeIn {
797
+ from { opacity: 0; transform: scale(0.95); }
798
+ to { opacity: 1; transform: scale(1); }
799
+ }
src/App.jsx ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState, useRef, useEffect } from 'react';
2
+ import './App.css';
3
+ import { VideoManagerLocal } from './utils/VideoManagerLocal';
4
+
5
+ // 引入我们刚刚拆分好的组件
6
+ // 注意:确保你的文件名和这里 import 的名字完全一致
7
+ import Home from './components/Home';
8
+ import FocusPageLocal from './components/FocusPageLocal';
9
+ import Achievement from './components/Achievement';
10
+ import Records from './components/Records';
11
+ import Customise from './components/Customise';
12
+ import Help from './components/Help';
13
+
14
+ function App() {
15
+ // 定义状态:当前显示哪个页面
16
+ const [activeTab, setActiveTab] = useState('home');
17
+
18
+ // 全局VideoManagerLocal实例和session状态
19
+ const videoManagerRef = useRef(null);
20
+ const [isSessionActive, setIsSessionActive] = useState(false);
21
+ const [sessionResult, setSessionResult] = useState(null);
22
+
23
+ // 初始化VideoManagerLocal(只创建一次)
24
+ useEffect(() => {
25
+ const callbacks = {
26
+ onSessionStart: () => {
27
+ setIsSessionActive(true);
28
+ setSessionResult(null);
29
+ },
30
+ onSessionEnd: (summary) => {
31
+ console.log("App.jsx: onSessionEnd called");
32
+ console.log("App.jsx: Session Summary:", summary);
33
+ setIsSessionActive(false);
34
+ if (summary) {
35
+ console.log("App.jsx: Setting sessionResult with summary");
36
+ setSessionResult(summary);
37
+ } else {
38
+ console.warn("App.jsx: Summary is null or undefined");
39
+ }
40
+ }
41
+ };
42
+
43
+ videoManagerRef.current = new VideoManagerLocal(callbacks);
44
+
45
+ // 清理函数:只在整个App卸载时才清理
46
+ return () => {
47
+ if (videoManagerRef.current) {
48
+ videoManagerRef.current.stopStreaming();
49
+ }
50
+ };
51
+ }, []);
52
+
53
+ // 页面切换时保持session连接,不自动断开
54
+
55
+ return (
56
+ <div className="app-container">
57
+ {/* 顶部导航栏 (Top Menu) - 它是全局共用的,所以留在 App.jsx 里 */}
58
+ <nav id="top-menu">
59
+ <button
60
+ className={`menu-btn ${activeTab === 'focus' ? 'active' : ''}`}
61
+ onClick={() => setActiveTab('focus')}
62
+ >
63
+ Start Focus {isSessionActive && <span style={{marginLeft: '8px', color: '#00FF00'}}>●</span>}
64
+ </button>
65
+ <div className="separator"></div>
66
+
67
+ <button
68
+ className={`menu-btn ${activeTab === 'achievement' ? 'active' : ''}`}
69
+ onClick={() => setActiveTab('achievement')}
70
+ >
71
+ My Achievement
72
+ </button>
73
+ <div className="separator"></div>
74
+
75
+ <button
76
+ className={`menu-btn ${activeTab === 'records' ? 'active' : ''}`}
77
+ onClick={() => setActiveTab('records')}
78
+ >
79
+ My Records
80
+ </button>
81
+ <div className="separator"></div>
82
+
83
+ <button
84
+ className={`menu-btn ${activeTab === 'customise' ? 'active' : ''}`}
85
+ onClick={() => setActiveTab('customise')}
86
+ >
87
+ Customise
88
+ </button>
89
+ <div className="separator"></div>
90
+
91
+ <button
92
+ className={`menu-btn ${activeTab === 'help' ? 'active' : ''}`}
93
+ onClick={() => setActiveTab('help')}
94
+ >
95
+ Help
96
+ </button>
97
+ </nav>
98
+
99
+ {/* 页面内容区域:根据 activeTab 的值,渲染对应的组件 */}
100
+
101
+ {/* Home 页:我们需要把 setActiveTab 传给它,因为它的 'Start' 按钮要能跳转页面 */}
102
+ {activeTab === 'home' && <Home setActiveTab={setActiveTab} />}
103
+
104
+ {/* FocusPageLocal 保持常驻,避免切页中断视频/连接 - 使用本地处理,不依赖WebRTC */}
105
+ <FocusPageLocal
106
+ videoManager={videoManagerRef.current}
107
+ sessionResult={sessionResult}
108
+ setSessionResult={setSessionResult}
109
+ isActive={activeTab === 'focus'}
110
+ />
111
+ {activeTab === 'achievement' && <Achievement />}
112
+ {activeTab === 'records' && <Records />}
113
+ {activeTab === 'customise' && <Customise />}
114
+ {activeTab === 'help' && <Help />}
115
+ </div>
116
+ );
117
+ }
118
+
119
+ export default App;
src/assets/react.svg ADDED
src/components/Achievement.jsx ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState, useEffect } from 'react';
2
+
3
+ function Achievement() {
4
+ const [stats, setStats] = useState({
5
+ total_sessions: 0,
6
+ total_focus_time: 0,
7
+ avg_focus_score: 0,
8
+ streak_days: 0
9
+ });
10
+ const [badges, setBadges] = useState([]);
11
+ const [loading, setLoading] = useState(true);
12
+
13
+ // 格式化时间显示
14
+ const formatTime = (seconds) => {
15
+ const hours = Math.floor(seconds / 3600);
16
+ const minutes = Math.floor((seconds % 3600) / 60);
17
+ if (hours > 0) return `${hours}h ${minutes}m`;
18
+ return `${minutes}m`;
19
+ };
20
+
21
+ // 加载统计数据
22
+ useEffect(() => {
23
+ fetch('/api/stats/summary')
24
+ .then(res => res.json())
25
+ .then(data => {
26
+ setStats(data);
27
+ calculateBadges(data);
28
+ setLoading(false);
29
+ })
30
+ .catch(err => {
31
+ console.error('Failed to load stats:', err);
32
+ setLoading(false);
33
+ });
34
+ }, []);
35
+
36
+ // 根据统计数据计算徽章
37
+ const calculateBadges = (data) => {
38
+ const earnedBadges = [];
39
+
40
+ // 首次会话徽章
41
+ if (data.total_sessions >= 1) {
42
+ earnedBadges.push({
43
+ id: 'first-session',
44
+ name: 'First Step',
45
+ description: 'Complete your first focus session',
46
+ icon: '🎯',
47
+ unlocked: true
48
+ });
49
+ }
50
+
51
+ // 10次会话徽章
52
+ if (data.total_sessions >= 10) {
53
+ earnedBadges.push({
54
+ id: 'ten-sessions',
55
+ name: 'Getting Started',
56
+ description: 'Complete 10 focus sessions',
57
+ icon: '⭐',
58
+ unlocked: true
59
+ });
60
+ }
61
+
62
+ // 50次会话徽章
63
+ if (data.total_sessions >= 50) {
64
+ earnedBadges.push({
65
+ id: 'fifty-sessions',
66
+ name: 'Dedicated',
67
+ description: 'Complete 50 focus sessions',
68
+ icon: '🏆',
69
+ unlocked: true
70
+ });
71
+ }
72
+
73
+ // 专注大师徽章 (平均专注度 > 80%)
74
+ if (data.avg_focus_score >= 0.8 && data.total_sessions >= 5) {
75
+ earnedBadges.push({
76
+ id: 'focus-master',
77
+ name: 'Focus Master',
78
+ description: 'Maintain 80%+ average focus score',
79
+ icon: '🧠',
80
+ unlocked: true
81
+ });
82
+ }
83
+
84
+ // 连续天数徽章
85
+ if (data.streak_days >= 7) {
86
+ earnedBadges.push({
87
+ id: 'week-streak',
88
+ name: 'Week Warrior',
89
+ description: '7 day streak',
90
+ icon: '🔥',
91
+ unlocked: true
92
+ });
93
+ }
94
+
95
+ if (data.streak_days >= 30) {
96
+ earnedBadges.push({
97
+ id: 'month-streak',
98
+ name: 'Month Master',
99
+ description: '30 day streak',
100
+ icon: '💎',
101
+ unlocked: true
102
+ });
103
+ }
104
+
105
+ // 时长徽章 (10小时+)
106
+ if (data.total_focus_time >= 36000) {
107
+ earnedBadges.push({
108
+ id: 'ten-hours',
109
+ name: 'Endurance',
110
+ description: '10+ hours total focus time',
111
+ icon: '⏱️',
112
+ unlocked: true
113
+ });
114
+ }
115
+
116
+ // 未解锁徽章(示例)
117
+ const allBadges = [
118
+ {
119
+ id: 'first-session',
120
+ name: 'First Step',
121
+ description: 'Complete your first focus session',
122
+ icon: '🎯',
123
+ unlocked: data.total_sessions >= 1
124
+ },
125
+ {
126
+ id: 'ten-sessions',
127
+ name: 'Getting Started',
128
+ description: 'Complete 10 focus sessions',
129
+ icon: '⭐',
130
+ unlocked: data.total_sessions >= 10
131
+ },
132
+ {
133
+ id: 'fifty-sessions',
134
+ name: 'Dedicated',
135
+ description: 'Complete 50 focus sessions',
136
+ icon: '🏆',
137
+ unlocked: data.total_sessions >= 50
138
+ },
139
+ {
140
+ id: 'focus-master',
141
+ name: 'Focus Master',
142
+ description: 'Maintain 80%+ average focus score',
143
+ icon: '🧠',
144
+ unlocked: data.avg_focus_score >= 0.8 && data.total_sessions >= 5
145
+ },
146
+ {
147
+ id: 'week-streak',
148
+ name: 'Week Warrior',
149
+ description: '7 day streak',
150
+ icon: '🔥',
151
+ unlocked: data.streak_days >= 7
152
+ },
153
+ {
154
+ id: 'month-streak',
155
+ name: 'Month Master',
156
+ description: '30 day streak',
157
+ icon: '💎',
158
+ unlocked: data.streak_days >= 30
159
+ },
160
+ {
161
+ id: 'ten-hours',
162
+ name: 'Endurance',
163
+ description: '10+ hours total focus time',
164
+ icon: '⏱️',
165
+ unlocked: data.total_focus_time >= 36000
166
+ },
167
+ {
168
+ id: 'hundred-sessions',
169
+ name: 'Centurion',
170
+ description: 'Complete 100 focus sessions',
171
+ icon: '👑',
172
+ unlocked: data.total_sessions >= 100
173
+ }
174
+ ];
175
+
176
+ setBadges(allBadges);
177
+ };
178
+
179
+ return (
180
+ <main id="page-c" className="page">
181
+ <h1 className="page-title">My Achievement</h1>
182
+
183
+ {loading ? (
184
+ <div style={{ textAlign: 'center', padding: '40px', color: '#888' }}>
185
+ Loading stats...
186
+ </div>
187
+ ) : (
188
+ <>
189
+ <div className="stats-grid">
190
+ <div className="stat-card">
191
+ <div className="stat-number" id="total-sessions">{stats.total_sessions}</div>
192
+ <div className="stat-label">Total Sessions</div>
193
+ </div>
194
+ <div className="stat-card">
195
+ <div className="stat-number" id="total-hours">{formatTime(stats.total_focus_time)}</div>
196
+ <div className="stat-label">Total Focus Time</div>
197
+ </div>
198
+ <div className="stat-card">
199
+ <div className="stat-number" id="avg-focus">{(stats.avg_focus_score * 100).toFixed(1)}%</div>
200
+ <div className="stat-label">Average Focus</div>
201
+ </div>
202
+ <div className="stat-card">
203
+ <div className="stat-number" id="current-streak">{stats.streak_days}</div>
204
+ <div className="stat-label">Day Streak</div>
205
+ </div>
206
+ </div>
207
+
208
+ <div className="achievements-section">
209
+ <h2>Badges</h2>
210
+ <div id="badges-container" className="badges-grid">
211
+ {badges.map(badge => (
212
+ <div
213
+ key={badge.id}
214
+ className={`badge ${badge.unlocked ? 'unlocked' : 'locked'}`}
215
+ style={{
216
+ padding: '20px',
217
+ textAlign: 'center',
218
+ border: '2px solid',
219
+ borderColor: badge.unlocked ? '#00FF00' : '#444',
220
+ borderRadius: '10px',
221
+ backgroundColor: badge.unlocked ? 'rgba(0, 255, 0, 0.1)' : 'rgba(68, 68, 68, 0.1)',
222
+ opacity: badge.unlocked ? 1 : 0.5,
223
+ transition: 'all 0.3s'
224
+ }}
225
+ >
226
+ <div style={{ fontSize: '48px', marginBottom: '10px' }}>
227
+ {badge.unlocked ? badge.icon : '🔒'}
228
+ </div>
229
+ <div style={{ fontWeight: 'bold', marginBottom: '5px' }}>
230
+ {badge.name}
231
+ </div>
232
+ <div style={{ fontSize: '12px', color: '#888' }}>
233
+ {badge.description}
234
+ </div>
235
+ </div>
236
+ ))}
237
+ </div>
238
+ </div>
239
+ </>
240
+ )}
241
+ </main>
242
+ );
243
+ }
244
+
245
+ export default Achievement;
src/components/Customise.jsx ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState, useEffect, useRef } from 'react';
2
+
3
+ function Customise() {
4
+ const [sensitivity, setSensitivity] = useState(6);
5
+ const [frameRate, setFrameRate] = useState(30);
6
+ const [notificationsEnabled, setNotificationsEnabled] = useState(true);
7
+ const [threshold, setThreshold] = useState(30);
8
+
9
+ // 引用隐藏的文件输入框
10
+ const fileInputRef = useRef(null);
11
+
12
+ // 1. 加载设置
13
+ useEffect(() => {
14
+ fetch('/api/settings')
15
+ .then(res => res.json())
16
+ .then(data => {
17
+ if (data) {
18
+ if (data.sensitivity) setSensitivity(data.sensitivity);
19
+ if (data.frame_rate) setFrameRate(data.frame_rate);
20
+ if (data.notification_threshold) setThreshold(data.notification_threshold);
21
+ if (data.notification_enabled !== undefined) setNotificationsEnabled(data.notification_enabled);
22
+ }
23
+ })
24
+ .catch(err => console.error("Failed to load settings", err));
25
+ }, []);
26
+
27
+ // 2. 保存设置
28
+ const handleSave = async () => {
29
+ const settings = {
30
+ sensitivity: parseInt(sensitivity),
31
+ frame_rate: parseInt(frameRate),
32
+ notification_enabled: notificationsEnabled,
33
+ notification_threshold: parseInt(threshold)
34
+ };
35
+
36
+ try {
37
+ const response = await fetch('/api/settings', {
38
+ method: 'PUT',
39
+ headers: { 'Content-Type': 'application/json' },
40
+ body: JSON.stringify(settings)
41
+ });
42
+ if (response.ok) alert("Settings saved successfully!");
43
+ else alert("Failed to save settings.");
44
+ } catch (error) {
45
+ alert("Error saving settings: " + error.message);
46
+ }
47
+ };
48
+
49
+ // 3. 导出数据 (Export)
50
+ const handleExport = async () => {
51
+ try {
52
+ // 请求获取所有历史记录
53
+ const response = await fetch('/api/sessions?filter=all');
54
+ if (!response.ok) throw new Error("Failed to fetch data");
55
+
56
+ const data = await response.json();
57
+
58
+ // 创建 JSON Blob
59
+ const jsonString = JSON.stringify(data, null, 2);
60
+ const blob = new Blob([jsonString], { type: 'application/json' });
61
+
62
+ // 创建临时下载链接
63
+ const url = URL.createObjectURL(blob);
64
+ const link = document.createElement('a');
65
+ link.href = url;
66
+ // 文件名包含当前日期
67
+ link.download = `focus-guard-backup-${new Date().toISOString().slice(0, 10)}.json`;
68
+
69
+ // 触发下载
70
+ document.body.appendChild(link);
71
+ link.click();
72
+
73
+ // 清理
74
+ document.body.removeChild(link);
75
+ URL.revokeObjectURL(url);
76
+ } catch (error) {
77
+ console.error(error);
78
+ alert("Export failed: " + error.message);
79
+ }
80
+ };
81
+
82
+ // 4. 触发导入文件选择
83
+ const triggerImport = () => {
84
+ fileInputRef.current.click();
85
+ };
86
+
87
+ // 5. 处理文件导入 (Import)
88
+ const handleFileChange = async (event) => {
89
+ const file = event.target.files[0];
90
+ if (!file) return;
91
+
92
+ const reader = new FileReader();
93
+ reader.onload = async (e) => {
94
+ try {
95
+ const content = e.target.result;
96
+ const sessions = JSON.parse(content);
97
+
98
+ // 简单的验证:确保它是一个数组
99
+ if (!Array.isArray(sessions)) {
100
+ throw new Error("Invalid file format: Expected a list of sessions.");
101
+ }
102
+
103
+ // 发送给后端进行存储
104
+ const response = await fetch('/api/import', {
105
+ method: 'POST',
106
+ headers: { 'Content-Type': 'application/json' },
107
+ body: JSON.stringify(sessions)
108
+ });
109
+
110
+ if (response.ok) {
111
+ const result = await response.json();
112
+ alert(`Success! Imported ${result.count} sessions.`);
113
+ } else {
114
+ alert("Import failed on server side.");
115
+ }
116
+ } catch (err) {
117
+ alert("Error parsing file: " + err.message);
118
+ }
119
+ // 清空 input,允许重复上传同一个文件
120
+ event.target.value = '';
121
+ };
122
+ reader.readAsText(file);
123
+ };
124
+
125
+ // 6. 清除历史 (Clear History)
126
+ const handleClearHistory = async () => {
127
+ if (!window.confirm("Are you sure? This will delete ALL your session history permanently.")) {
128
+ return;
129
+ }
130
+
131
+ try {
132
+ const response = await fetch('/api/history', { method: 'DELETE' });
133
+ if (response.ok) {
134
+ alert("All history has been cleared.");
135
+ } else {
136
+ alert("Failed to clear history.");
137
+ }
138
+ } catch (err) {
139
+ alert("Error: " + err.message);
140
+ }
141
+ };
142
+
143
+ return (
144
+ <main id="page-e" className="page">
145
+ <h1 className="page-title">Customise</h1>
146
+
147
+ <div className="settings-container">
148
+ {/* Detection Settings */}
149
+ <div className="setting-group">
150
+ <h2>Detection Settings</h2>
151
+ <div className="setting-item">
152
+ <label htmlFor="sensitivity-slider">Detection Sensitivity</label>
153
+ <div className="slider-group">
154
+ <input type="range" id="sensitivity-slider" min="1" max="10" value={sensitivity} onChange={(e) => setSensitivity(e.target.value)} />
155
+ <span id="sensitivity-value">{sensitivity}</span>
156
+ </div>
157
+ <p className="setting-description">Higher values require stricter focus criteria</p>
158
+ </div>
159
+ <div className="setting-item">
160
+ <label htmlFor="default-framerate">Default Frame Rate</label>
161
+ <div className="slider-group">
162
+ <input type="range" id="default-framerate" min="5" max="60" value={frameRate} onChange={(e) => setFrameRate(e.target.value)} />
163
+ <span id="framerate-value">{frameRate}</span> FPS
164
+ </div>
165
+ </div>
166
+ </div>
167
+
168
+ {/* Notifications */}
169
+ <div className="setting-group">
170
+ <h2>Notifications</h2>
171
+ <div className="setting-item">
172
+ <label>
173
+ <input type="checkbox" id="enable-notifications" checked={notificationsEnabled} onChange={(e) => setNotificationsEnabled(e.target.checked)} />
174
+ Enable distraction notifications
175
+ </label>
176
+ </div>
177
+ <div className="setting-item">
178
+ <label htmlFor="notification-threshold">Alert after (seconds)</label>
179
+ <input type="number" id="notification-threshold" value={threshold} onChange={(e) => setThreshold(e.target.value)} min="5" max="300" />
180
+ </div>
181
+ </div>
182
+
183
+ {/* Data Management */}
184
+ <div className="setting-group">
185
+ <h2>Data Management</h2>
186
+
187
+ {/* 隐藏的文件输入框,只接受 json */}
188
+ <input
189
+ type="file"
190
+ ref={fileInputRef}
191
+ style={{ display: 'none' }}
192
+ accept=".json"
193
+ onChange={handleFileChange}
194
+ />
195
+
196
+ <div style={{ display: 'flex', gap: '10px', justifyContent: 'center', flexWrap: 'wrap' }}>
197
+ {/* Export 按钮 */}
198
+ <button id="export-data" className="action-btn blue" onClick={handleExport} style={{ width: '30%', minWidth: '120px' }}>
199
+ Export Data
200
+ </button>
201
+
202
+ {/* Import 按钮 */}
203
+ <button id="import-data" className="action-btn yellow" onClick={triggerImport} style={{ width: '30%', minWidth: '120px' }}>
204
+ Import Data
205
+ </button>
206
+
207
+ {/* Clear 按钮 */}
208
+ <button id="clear-history" className="action-btn red" onClick={handleClearHistory} style={{ width: '30%', minWidth: '120px' }}>
209
+ Clear History
210
+ </button>
211
+ </div>
212
+ </div>
213
+
214
+ <button id="save-settings" className="btn-main" onClick={handleSave}>Save Settings</button>
215
+ </div>
216
+ </main>
217
+ );
218
+ }
219
+
220
+ export default Customise;
src/components/FocusPage.jsx ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState, useEffect } from 'react';
2
+
3
+ function FocusPage({ videoManager, sessionResult, setSessionResult, isActive, displayVideoRef }) {
4
+ const [currentFrame, setCurrentFrame] = useState(30);
5
+ const [timelineEvents, setTimelineEvents] = useState([]);
6
+
7
+ const videoRef = displayVideoRef;
8
+
9
+ // 辅助函数:格式化时间
10
+ const formatDuration = (seconds) => {
11
+ // 如果是 0,直接显示 0s (或者你可以保留原来的 0m 0s)
12
+ if (seconds === 0) return "0s";
13
+
14
+ const mins = Math.floor(seconds / 60);
15
+ const secs = Math.floor(seconds % 60);
16
+ return `${mins}m ${secs}s`;
17
+ };
18
+
19
+ useEffect(() => {
20
+ if (!videoManager) return;
21
+
22
+ // 设置回调函数来更新时间轴
23
+ const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
24
+ videoManager.callbacks.onStatusUpdate = (isFocused) => {
25
+ setTimelineEvents(prev => {
26
+ const newEvents = [...prev, { isFocused, timestamp: Date.now() }];
27
+ if (newEvents.length > 60) newEvents.shift();
28
+ return newEvents;
29
+ });
30
+ // 调用原始回调(如果有)
31
+ if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
32
+ };
33
+
34
+ // 清理函数:不再自动停止session,只清理回调
35
+ return () => {
36
+ if (videoManager) {
37
+ videoManager.callbacks.onStatusUpdate = originalOnStatusUpdate;
38
+ }
39
+ };
40
+ }, [videoManager]);
41
+
42
+ const handleStart = async () => {
43
+ try {
44
+ if (videoManager) {
45
+ setSessionResult(null); // 开始时清除结果层
46
+ setTimelineEvents([]);
47
+
48
+ console.log('🎬 Initializing camera...');
49
+ await videoManager.initCamera(videoRef.current);
50
+ console.log('✅ Camera initialized');
51
+
52
+ console.log('🚀 Starting streaming...');
53
+ await videoManager.startStreaming();
54
+ console.log('✅ Streaming started successfully');
55
+ }
56
+ } catch (err) {
57
+ console.error('❌ Start error:', err);
58
+ let errorMessage = "Failed to start: ";
59
+
60
+ if (err.name === 'NotAllowedError') {
61
+ errorMessage += "Camera permission denied. Please allow camera access.";
62
+ } else if (err.name === 'NotFoundError') {
63
+ errorMessage += "No camera found. Please connect a camera.";
64
+ } else if (err.name === 'NotReadableError') {
65
+ errorMessage += "Camera is already in use by another application.";
66
+ } else if (err.message && err.message.includes('HTTPS')) {
67
+ errorMessage += "Camera requires HTTPS. Please use a secure connection.";
68
+ } else {
69
+ errorMessage += err.message || "Unknown error occurred.";
70
+ }
71
+
72
+ alert(errorMessage + "\n\nCheck browser console for details.");
73
+ }
74
+ };
75
+
76
+ const handleStop = () => {
77
+ if (videoManager) {
78
+ videoManager.stopStreaming();
79
+ }
80
+ };
81
+
82
+ const handlePiP = async () => {
83
+ try {
84
+ const sourceVideoEl = videoRef.current;
85
+ if (!sourceVideoEl) {
86
+ alert('Video not ready. Please click Start first.');
87
+ return;
88
+ }
89
+
90
+ if (document.pictureInPictureElement) {
91
+ await document.exitPictureInPicture();
92
+ return;
93
+ }
94
+
95
+ sourceVideoEl.disablePictureInPicture = false;
96
+
97
+ if (typeof sourceVideoEl.webkitSetPresentationMode === 'function') {
98
+ sourceVideoEl.play().catch(() => {});
99
+ sourceVideoEl.webkitSetPresentationMode('picture-in-picture');
100
+ return;
101
+ }
102
+
103
+ if (!document.pictureInPictureEnabled || typeof sourceVideoEl.requestPictureInPicture !== 'function') {
104
+ alert('Picture-in-Picture is not supported in this browser.');
105
+ return;
106
+ }
107
+
108
+ const pipPromise = sourceVideoEl.requestPictureInPicture();
109
+ sourceVideoEl.play().catch(() => {});
110
+ await pipPromise;
111
+ } catch (err) {
112
+ console.error('PiP error:', err);
113
+ alert('Failed to enter Picture-in-Picture.');
114
+ }
115
+ };
116
+
117
+ // 浮窗功能
118
+ const handleFloatingWindow = () => {
119
+ handlePiP();
120
+ };
121
+
122
+ // ==========================================
123
+ // 新增功能:预览按钮的处理函数
124
+ // ==========================================
125
+ const handlePreview = () => {
126
+ // 强制设置一个 0 分 0 秒的假数据,触发 overlay 显示
127
+ setSessionResult({
128
+ duration_seconds: 0,
129
+ focus_score: 0
130
+ });
131
+ };
132
+
133
+ const handleCloseOverlay = () => {
134
+ setSessionResult(null);
135
+ };
136
+ // ==========================================
137
+
138
+ const handleFrameChange = (val) => {
139
+ setCurrentFrame(val);
140
+ if (videoManager) {
141
+ videoManager.setFrameRate(val);
142
+ }
143
+ };
144
+
145
+ const pageStyle = isActive
146
+ ? undefined
147
+ : {
148
+ position: 'absolute',
149
+ width: '1px',
150
+ height: '1px',
151
+ overflow: 'hidden',
152
+ opacity: 0,
153
+ pointerEvents: 'none'
154
+ };
155
+
156
+ return (
157
+ <main id="page-b" className="page" style={pageStyle}>
158
+ {/* 1. Camera / Display Area */}
159
+ <section id="display-area" style={{ position: 'relative', overflow: 'hidden' }}>
160
+ <video
161
+ ref={videoRef}
162
+ muted
163
+ playsInline
164
+ autoPlay
165
+ style={{ width: '100%', height: '100%', objectFit: 'contain' }}
166
+ />
167
+
168
+ {/* 结果覆盖层 */}
169
+ {sessionResult && (
170
+ <div className="session-result-overlay">
171
+ <h3>Session Complete!</h3>
172
+ <div className="result-item">
173
+ <span className="label">Duration:</span>
174
+ <span className="value">{formatDuration(sessionResult.duration_seconds)}</span>
175
+ </div>
176
+ <div className="result-item">
177
+ <span className="label">Focus Score:</span>
178
+ <span className="value">{(sessionResult.focus_score * 100).toFixed(1)}%</span>
179
+ </div>
180
+
181
+ {/* 新增:加一个小按钮方便关闭预览 */}
182
+ <button
183
+ onClick={handleCloseOverlay}
184
+ style={{
185
+ marginTop: '20px',
186
+ padding: '8px 20px',
187
+ background: 'transparent',
188
+ border: '1px solid white',
189
+ color: 'white',
190
+ borderRadius: '20px',
191
+ cursor: 'pointer'
192
+ }}
193
+ >
194
+ Close
195
+ </button>
196
+ </div>
197
+ )}
198
+
199
+ </section>
200
+
201
+ {/* 2. Timeline Area */}
202
+ <section id="timeline-area">
203
+ <div className="timeline-label">Timeline</div>
204
+ <div id="timeline-visuals">
205
+ {timelineEvents.map((event, index) => (
206
+ <div
207
+ key={index}
208
+ className="timeline-block"
209
+ style={{
210
+ backgroundColor: event.isFocused ? '#00FF00' : '#FF0000',
211
+ width: '10px',
212
+ height: '20px',
213
+ display: 'inline-block',
214
+ marginRight: '2px',
215
+ borderRadius: '2px'
216
+ }}
217
+ title={event.isFocused ? 'Focused' : 'Distracted'}
218
+ />
219
+ ))}
220
+ </div>
221
+ <div id="timeline-line"></div>
222
+ </section>
223
+
224
+ {/* 3. Control Buttons */}
225
+ <section id="control-panel">
226
+ <button id="btn-cam-start" className="action-btn green" onClick={handleStart}>Start</button>
227
+ <button id="btn-floating" className="action-btn yellow" onClick={handleFloatingWindow}>Floating Window</button>
228
+
229
+ {/* 修改:把 Models 按钮暂时改成 Preview 按钮,或者加在它后面 */}
230
+ <button
231
+ id="btn-preview"
232
+ className="action-btn"
233
+ style={{ backgroundColor: '#6c5ce7' }} // 紫色按钮以示区别
234
+ onClick={handlePreview}
235
+ >
236
+ Preview Result
237
+ </button>
238
+
239
+ <button id="btn-cam-stop" className="action-btn red" onClick={handleStop}>Stop</button>
240
+ </section>
241
+
242
+ {/* 4. Frame Control */}
243
+ <section id="frame-control">
244
+ <label htmlFor="frame-slider">Frame</label>
245
+ <input
246
+ type="range"
247
+ id="frame-slider"
248
+ min="1"
249
+ max="60"
250
+ value={currentFrame}
251
+ onChange={(e) => handleFrameChange(e.target.value)}
252
+ />
253
+ <input
254
+ type="number"
255
+ id="frame-input"
256
+ value={currentFrame}
257
+ onChange={(e) => handleFrameChange(e.target.value)}
258
+ />
259
+ </section>
260
+ </main>
261
+ );
262
+ }
263
+
264
+ export default FocusPage;
src/components/FocusPageLocal.jsx ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState, useEffect, useRef } from 'react';
2
+
3
+ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActive }) {
4
+ const [currentFrame, setCurrentFrame] = useState(15);
5
+ const [timelineEvents, setTimelineEvents] = useState([]);
6
+ const [stats, setStats] = useState(null);
7
+
8
+ const localVideoRef = useRef(null);
9
+ const displayCanvasRef = useRef(null);
10
+ const pipVideoRef = useRef(null); // 用于 PiP 的隐藏 video 元素
11
+ const pipStreamRef = useRef(null);
12
+
13
+ // 辅助函数:格式化时间
14
+ const formatDuration = (seconds) => {
15
+ if (seconds === 0) return "0s";
16
+ const mins = Math.floor(seconds / 60);
17
+ const secs = Math.floor(seconds % 60);
18
+ return `${mins}m ${secs}s`;
19
+ };
20
+
21
+ useEffect(() => {
22
+ if (!videoManager) return;
23
+
24
+ // 设置回调函数来更新时间轴
25
+ const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
26
+ videoManager.callbacks.onStatusUpdate = (isFocused) => {
27
+ setTimelineEvents(prev => {
28
+ const newEvents = [...prev, { isFocused, timestamp: Date.now() }];
29
+ if (newEvents.length > 60) newEvents.shift();
30
+ return newEvents;
31
+ });
32
+ if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
33
+ };
34
+
35
+ // 定期更新统计信息
36
+ const statsInterval = setInterval(() => {
37
+ if (videoManager && videoManager.getStats) {
38
+ setStats(videoManager.getStats());
39
+ }
40
+ }, 1000);
41
+
42
+ return () => {
43
+ if (videoManager) {
44
+ videoManager.callbacks.onStatusUpdate = originalOnStatusUpdate;
45
+ }
46
+ clearInterval(statsInterval);
47
+ };
48
+ }, [videoManager]);
49
+
50
+ const handleStart = async () => {
51
+ try {
52
+ if (videoManager) {
53
+ setSessionResult(null);
54
+ setTimelineEvents([]);
55
+
56
+ console.log('Initializing local camera...');
57
+ await videoManager.initCamera(localVideoRef.current, displayCanvasRef.current);
58
+ console.log('Camera initialized');
59
+
60
+ console.log('Starting local streaming...');
61
+ await videoManager.startStreaming();
62
+ console.log('Streaming started successfully');
63
+ }
64
+ } catch (err) {
65
+ console.error('Start error:', err);
66
+ let errorMessage = "Failed to start: ";
67
+
68
+ if (err.name === 'NotAllowedError') {
69
+ errorMessage += "Camera permission denied. Please allow camera access.";
70
+ } else if (err.name === 'NotFoundError') {
71
+ errorMessage += "No camera found. Please connect a camera.";
72
+ } else if (err.name === 'NotReadableError') {
73
+ errorMessage += "Camera is already in use by another application.";
74
+ } else {
75
+ errorMessage += err.message || "Unknown error occurred.";
76
+ }
77
+
78
+ alert(errorMessage + "\n\nCheck browser console for details.");
79
+ }
80
+ };
81
+
82
+ const handleStop = async () => {
83
+ if (videoManager) {
84
+ videoManager.stopStreaming();
85
+ }
86
+ try {
87
+ if (document.pictureInPictureElement === pipVideoRef.current) {
88
+ await document.exitPictureInPicture();
89
+ }
90
+ } catch (_) {}
91
+ if (pipVideoRef.current) {
92
+ pipVideoRef.current.pause();
93
+ pipVideoRef.current.srcObject = null;
94
+ }
95
+ if (pipStreamRef.current) {
96
+ pipStreamRef.current.getTracks().forEach(t => t.stop());
97
+ pipStreamRef.current = null;
98
+ }
99
+ };
100
+
101
+ const handlePiP = async () => {
102
+ try {
103
+ // 检查是否有视频管理器和是否在运行
104
+ if (!videoManager || !videoManager.isStreaming) {
105
+ alert('Please start the video first.');
106
+ return;
107
+ }
108
+
109
+ if (!displayCanvasRef.current) {
110
+ alert('Video not ready.');
111
+ return;
112
+ }
113
+
114
+ // 如果已经在 PiP 模式,且是本视频,退出
115
+ if (document.pictureInPictureElement === pipVideoRef.current) {
116
+ await document.exitPictureInPicture();
117
+ console.log('PiP exited');
118
+ return;
119
+ }
120
+
121
+ // 检查浏览器支持
122
+ if (!document.pictureInPictureEnabled) {
123
+ alert('Picture-in-Picture is not supported in this browser.');
124
+ return;
125
+ }
126
+
127
+ // 创建或获取 PiP video 元素
128
+ const pipVideo = pipVideoRef.current;
129
+ if (!pipVideo) {
130
+ alert('PiP video element not ready.');
131
+ return;
132
+ }
133
+
134
+ const isSafariPiP = typeof pipVideo.webkitSetPresentationMode === 'function';
135
+
136
+ // 优先用画布流(带检测框),失败再回退到摄像头流
137
+ let stream = pipStreamRef.current;
138
+ if (!stream) {
139
+ const capture = displayCanvasRef.current.captureStream;
140
+ if (typeof capture === 'function') {
141
+ stream = capture.call(displayCanvasRef.current, 30);
142
+ }
143
+ if (!stream || stream.getTracks().length === 0) {
144
+ const cameraStream = localVideoRef.current?.srcObject;
145
+ if (!cameraStream) {
146
+ alert('Camera stream not ready.');
147
+ return;
148
+ }
149
+ stream = cameraStream;
150
+ }
151
+ pipStreamRef.current = stream;
152
+ }
153
+
154
+ // 确保��有轨道
155
+ if (!stream || stream.getTracks().length === 0) {
156
+ alert('Failed to capture video stream from canvas.');
157
+ return;
158
+ }
159
+
160
+ pipVideo.srcObject = stream;
161
+
162
+ // 播放视频(Safari 可能不会触发 onloadedmetadata)
163
+ if (pipVideo.readyState < 2) {
164
+ await new Promise((resolve) => {
165
+ const onReady = () => {
166
+ pipVideo.removeEventListener('loadeddata', onReady);
167
+ pipVideo.removeEventListener('canplay', onReady);
168
+ resolve();
169
+ };
170
+ pipVideo.addEventListener('loadeddata', onReady);
171
+ pipVideo.addEventListener('canplay', onReady);
172
+ // 兜底:短延迟后继续尝试
173
+ setTimeout(resolve, 600);
174
+ });
175
+ }
176
+
177
+ try {
178
+ await pipVideo.play();
179
+ } catch (_) {
180
+ // Safari 可能拒绝自动播放,但仍可进入 PiP
181
+ }
182
+
183
+ // Safari 支持(优先)
184
+ if (isSafariPiP) {
185
+ try {
186
+ pipVideo.webkitSetPresentationMode('picture-in-picture');
187
+ console.log('PiP activated (Safari)');
188
+ return;
189
+ } catch (e) {
190
+ // 如果画布流失败,回退到摄像头流再试一次
191
+ const cameraStream = localVideoRef.current?.srcObject;
192
+ if (cameraStream && cameraStream !== pipVideo.srcObject) {
193
+ pipVideo.srcObject = cameraStream;
194
+ try {
195
+ await pipVideo.play();
196
+ } catch (_) {}
197
+ pipVideo.webkitSetPresentationMode('picture-in-picture');
198
+ console.log('PiP activated (Safari fallback)');
199
+ return;
200
+ }
201
+ throw e;
202
+ }
203
+ }
204
+
205
+ // 标准 API
206
+ if (typeof pipVideo.requestPictureInPicture === 'function') {
207
+ await pipVideo.requestPictureInPicture();
208
+ console.log('PiP activated');
209
+ } else {
210
+ alert('Picture-in-Picture is not supported in this browser.');
211
+ }
212
+
213
+ } catch (err) {
214
+ console.error('PiP error:', err);
215
+ alert('Failed to enter Picture-in-Picture: ' + err.message);
216
+ }
217
+ };
218
+
219
+ const handleFloatingWindow = () => {
220
+ handlePiP();
221
+ };
222
+
223
+ const handleFrameChange = (val) => {
224
+ const rate = parseInt(val);
225
+ setCurrentFrame(rate);
226
+ if (videoManager) {
227
+ videoManager.setFrameRate(rate);
228
+ }
229
+ };
230
+
231
+ const handlePreview = () => {
232
+ if (!videoManager || !videoManager.isStreaming) {
233
+ alert('Please start a session first.');
234
+ return;
235
+ }
236
+
237
+ // 获取当前统计数据
238
+ const currentStats = videoManager.getStats();
239
+
240
+ if (!currentStats.sessionId) {
241
+ alert('No active session.');
242
+ return;
243
+ }
244
+
245
+ // 计算当前持续时间(从 session 开始到现在)
246
+ const sessionDuration = Math.floor((Date.now() - (videoManager.sessionStartTime || Date.now())) / 1000);
247
+
248
+ // 计算当前专注分数
249
+ const focusScore = currentStats.framesProcessed > 0
250
+ ? (currentStats.framesProcessed * (currentStats.currentStatus ? 1 : 0)) / currentStats.framesProcessed
251
+ : 0;
252
+
253
+ // 显示当前实时数据
254
+ setSessionResult({
255
+ duration_seconds: sessionDuration,
256
+ focus_score: focusScore,
257
+ total_frames: currentStats.framesProcessed,
258
+ focused_frames: Math.floor(currentStats.framesProcessed * focusScore)
259
+ });
260
+ };
261
+
262
+ const handleCloseOverlay = () => {
263
+ setSessionResult(null);
264
+ };
265
+
266
+ const pageStyle = isActive
267
+ ? undefined
268
+ : {
269
+ position: 'absolute',
270
+ width: '1px',
271
+ height: '1px',
272
+ overflow: 'hidden',
273
+ opacity: 0,
274
+ pointerEvents: 'none'
275
+ };
276
+
277
+ useEffect(() => {
278
+ return () => {
279
+ if (pipVideoRef.current) {
280
+ pipVideoRef.current.pause();
281
+ pipVideoRef.current.srcObject = null;
282
+ }
283
+ if (pipStreamRef.current) {
284
+ pipStreamRef.current.getTracks().forEach(t => t.stop());
285
+ pipStreamRef.current = null;
286
+ }
287
+ };
288
+ }, []);
289
+
290
+ return (
291
+ <main id="page-b" className="page" style={pageStyle}>
292
+ {/* 1. Camera / Display Area */}
293
+ <section id="display-area" style={{ position: 'relative', overflow: 'hidden' }}>
294
+ {/* 用于 PiP 的隐藏 video 元素(保持在 DOM 以提高兼容性) */}
295
+ <video
296
+ ref={pipVideoRef}
297
+ muted
298
+ playsInline
299
+ autoPlay
300
+ style={{
301
+ position: 'absolute',
302
+ width: '1px',
303
+ height: '1px',
304
+ opacity: 0,
305
+ pointerEvents: 'none'
306
+ }}
307
+ />
308
+ {/* 本地视频流(隐藏,仅用于截图) */}
309
+ <video
310
+ ref={localVideoRef}
311
+ muted
312
+ playsInline
313
+ autoPlay
314
+ style={{ display: 'none' }}
315
+ />
316
+
317
+ {/* 显示处理后的视频(使用 Canvas) */}
318
+ <canvas
319
+ ref={displayCanvasRef}
320
+ width={640}
321
+ height={480}
322
+ style={{
323
+ width: '100%',
324
+ height: '100%',
325
+ objectFit: 'contain',
326
+ backgroundColor: '#000'
327
+ }}
328
+ />
329
+
330
+ {/* 结果覆盖层 */}
331
+ {sessionResult && (
332
+ <div className="session-result-overlay">
333
+ <h3>Session Complete!</h3>
334
+ <div className="result-item">
335
+ <span className="label">Duration:</span>
336
+ <span className="value">{formatDuration(sessionResult.duration_seconds)}</span>
337
+ </div>
338
+ <div className="result-item">
339
+ <span className="label">Focus Score:</span>
340
+ <span className="value">{(sessionResult.focus_score * 100).toFixed(1)}%</span>
341
+ </div>
342
+
343
+ <button
344
+ onClick={handleCloseOverlay}
345
+ style={{
346
+ marginTop: '20px',
347
+ padding: '8px 20px',
348
+ background: 'transparent',
349
+ border: '1px solid white',
350
+ color: 'white',
351
+ borderRadius: '20px',
352
+ cursor: 'pointer'
353
+ }}
354
+ >
355
+ Close
356
+ </button>
357
+ </div>
358
+ )}
359
+
360
+ {/* 性能统计显示(开发模式) */}
361
+ {stats && stats.isStreaming && (
362
+ <div style={{
363
+ position: 'absolute',
364
+ top: '10px',
365
+ right: '10px',
366
+ background: 'rgba(0,0,0,0.7)',
367
+ color: 'white',
368
+ padding: '10px',
369
+ borderRadius: '5px',
370
+ fontSize: '12px',
371
+ fontFamily: 'monospace'
372
+ }}>
373
+ <div>Session: {stats.sessionId}</div>
374
+ <div>Sent: {stats.framesSent}</div>
375
+ <div>Processed: {stats.framesProcessed}</div>
376
+ <div>Latency: {stats.avgLatency.toFixed(0)}ms</div>
377
+ <div>Status: {stats.currentStatus ? 'Focused' : 'Not Focused'}</div>
378
+ <div>Confidence: {(stats.lastConfidence * 100).toFixed(1)}%</div>
379
+ </div>
380
+ )}
381
+ </section>
382
+
383
+ {/* 2. Timeline Area */}
384
+ <section id="timeline-area">
385
+ <div className="timeline-label">Timeline</div>
386
+ <div id="timeline-visuals">
387
+ {timelineEvents.map((event, index) => (
388
+ <div
389
+ key={index}
390
+ className="timeline-block"
391
+ style={{
392
+ backgroundColor: event.isFocused ? '#00FF00' : '#FF0000',
393
+ width: '10px',
394
+ height: '20px',
395
+ display: 'inline-block',
396
+ marginRight: '2px',
397
+ borderRadius: '2px'
398
+ }}
399
+ title={event.isFocused ? 'Focused' : 'Distracted'}
400
+ />
401
+ ))}
402
+ </div>
403
+ <div id="timeline-line"></div>
404
+ </section>
405
+
406
+ {/* 3. Control Buttons */}
407
+ <section id="control-panel">
408
+ <button id="btn-cam-start" className="action-btn green" onClick={handleStart}>
409
+ Start
410
+ </button>
411
+
412
+ <button id="btn-floating" className="action-btn yellow" onClick={handleFloatingWindow}>
413
+ Floating Window
414
+ </button>
415
+
416
+ <button
417
+ id="btn-preview"
418
+ className="action-btn"
419
+ style={{ backgroundColor: '#6c5ce7' }}
420
+ onClick={handlePreview}
421
+ >
422
+ Preview Result
423
+ </button>
424
+
425
+ <button id="btn-cam-stop" className="action-btn red" onClick={handleStop}>
426
+ Stop
427
+ </button>
428
+ </section>
429
+
430
+ {/* 4. Frame Control */}
431
+ <section id="frame-control">
432
+ <label htmlFor="frame-slider">Frame Rate (FPS)</label>
433
+ <input
434
+ type="range"
435
+ id="frame-slider"
436
+ min="5"
437
+ max="30"
438
+ value={currentFrame}
439
+ onChange={(e) => handleFrameChange(e.target.value)}
440
+ />
441
+ <input
442
+ type="number"
443
+ id="frame-input"
444
+ min="5"
445
+ max="30"
446
+ value={currentFrame}
447
+ onChange={(e) => handleFrameChange(e.target.value)}
448
+ />
449
+ </section>
450
+ </main>
451
+ );
452
+ }
453
+
454
+ export default FocusPageLocal;
src/components/Help.jsx ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react';
2
+
3
+ function Help() {
4
+ return (
5
+ <main id="page-f" className="page">
6
+ <h1 className="page-title">Help</h1>
7
+
8
+ <a
9
+ className="fake-ad"
10
+ href="https://www.kcl.ac.uk/"
11
+ target="_blank"
12
+ rel="noreferrer"
13
+ aria-label="Sponsored: King's College London"
14
+ >
15
+ <div className="fake-ad-badge">Sponsored</div>
16
+ <img src="/assets/111.jpg" alt="King's College London campus sign" />
17
+ <div className="fake-ad-cta">Learn More</div>
18
+ </a>
19
+
20
+ <div className="help-container">
21
+ <section className="help-section">
22
+ <h2>How to Use Focus Guard</h2>
23
+ <ol>
24
+ <li>Click "Start" or navigate to "Start Focus" in the menu</li>
25
+ <li>Allow camera access when prompted</li>
26
+ <li>Click the green "Start" button to begin monitoring</li>
27
+ <li>Position yourself in front of the camera</li>
28
+ <li>The system will track your focus in real-time</li>
29
+ <li>Click "Stop" when you're done to save the session</li>
30
+ </ol>
31
+ </section>
32
+
33
+ <section className="help-section">
34
+ <h2>What is "Focused"?</h2>
35
+ <p>The system considers you focused when:</p>
36
+ <ul>
37
+ <li>You are clearly visible in the camera frame</li>
38
+ <li>You are centered in the view</li>
39
+ <li>Your face is directed toward the screen</li>
40
+ <li>No other people are detected in the frame</li>
41
+ </ul>
42
+ </section>
43
+
44
+ <section className="help-section">
45
+ <h2>Adjusting Settings</h2>
46
+ <p><strong>Frame Rate:</strong> Lower values reduce CPU usage but update less frequently. Recommended: 15-30 FPS.</p>
47
+ <p><strong>Sensitivity:</strong> Higher values require stricter focus criteria. Adjust based on your setup.</p>
48
+ </section>
49
+
50
+ <section className="help-section">
51
+ <h2>Privacy & Data</h2>
52
+ <p>All video processing happens in real-time. No video frames are stored - only detection metadata (focus status, timestamps) is saved in your local database.</p>
53
+ </section>
54
+
55
+ <section className="help-section">
56
+ <h2>FAQ</h2>
57
+ <details>
58
+ <summary>Why is my focus score low?</summary>
59
+ <p>Ensure good lighting, center yourself in the camera frame, and adjust sensitivity settings in the Customise page.</p>
60
+ </details>
61
+ <details>
62
+ <summary>Can I use this without a camera?</summary>
63
+ <p>No, camera access is required for focus detection.</p>
64
+ </details>
65
+ <details>
66
+ <summary>Does this work on mobile?</summary>
67
+ <p>The app works on mobile browsers but performance may vary due to processing requirements.</p>
68
+ </details>
69
+ <details>
70
+ <summary>Is my data private?</summary>
71
+ <p>Yes! All processing happens locally. Video frames are analyzed in real-time and never stored. Only metadata is saved.</p>
72
+ </details>
73
+ </section>
74
+
75
+ <section className="help-section">
76
+ <h2>Technical Info</h2>
77
+ <p><strong>Model:</strong> YOLOv8n (Nano)</p>
78
+ <p><strong>Detection:</strong> Real-time person detection with pose analysis</p>
79
+ <p><strong>Storage:</strong> SQLite local database</p>
80
+ <p><strong>Framework:</strong> FastAPI + Native JavaScript</p>
81
+ </section>
82
+ </div>
83
+ </main>
84
+ );
85
+ }
86
+
87
+ export default Help;
src/components/Home.jsx ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react';
2
+
3
+ // receive setActiveTab to change page
4
+ function Home({ setActiveTab }) {
5
+ return (
6
+ <main id="page-a" className="page">
7
+ <h1>FocusGuard</h1>
8
+ <p>Your productivity monitor assistant.</p>
9
+
10
+ {/* click button change to 'focus' page */}
11
+ <button
12
+ id="start-button"
13
+ className="btn-main"
14
+ onClick={() => setActiveTab('focus')}
15
+ >
16
+ Start
17
+ </button>
18
+ </main>
19
+ );
20
+ }
21
+
22
+ export default Home;
src/components/Records.jsx ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState, useEffect, useRef } from 'react';
2
+
3
+ function Records() {
4
+ const [filter, setFilter] = useState('all');
5
+ const [sessions, setSessions] = useState([]);
6
+ const [loading, setLoading] = useState(false);
7
+ const chartRef = useRef(null);
8
+
9
+ // 格式化时间
10
+ const formatDuration = (seconds) => {
11
+ const mins = Math.floor(seconds / 60);
12
+ const secs = seconds % 60;
13
+ return `${mins}m ${secs}s`;
14
+ };
15
+
16
+ // 格式化日期
17
+ const formatDate = (dateString) => {
18
+ const date = new Date(dateString);
19
+ return date.toLocaleDateString('en-US', {
20
+ month: 'short',
21
+ day: 'numeric',
22
+ hour: '2-digit',
23
+ minute: '2-digit'
24
+ });
25
+ };
26
+
27
+ // 加载会话数据
28
+ const loadSessions = async (filterType) => {
29
+ setLoading(true);
30
+ try {
31
+ const response = await fetch(`/api/sessions?filter=${filterType}&limit=50`);
32
+ const data = await response.json();
33
+ setSessions(data);
34
+ drawChart(data);
35
+ } catch (error) {
36
+ console.error('Failed to load sessions:', error);
37
+ } finally {
38
+ setLoading(false);
39
+ }
40
+ };
41
+
42
+ // 绘制图表
43
+ const drawChart = (data) => {
44
+ const canvas = chartRef.current;
45
+ if (!canvas) return;
46
+
47
+ const ctx = canvas.getContext('2d');
48
+ const width = canvas.width = canvas.offsetWidth;
49
+ const height = canvas.height = 300;
50
+
51
+ // 清空画布
52
+ ctx.clearRect(0, 0, width, height);
53
+
54
+ if (data.length === 0) {
55
+ ctx.fillStyle = '#999';
56
+ ctx.font = '16px Nunito';
57
+ ctx.textAlign = 'center';
58
+ ctx.fillText('No data available', width / 2, height / 2);
59
+ return;
60
+ }
61
+
62
+ // 准备数据 (最多显示最近20个会话)
63
+ const displayData = data.slice(0, 20).reverse();
64
+ const padding = 50;
65
+ const chartWidth = width - padding * 2;
66
+ const chartHeight = height - padding * 2;
67
+ const barWidth = chartWidth / displayData.length;
68
+
69
+ // 找到最大值用于缩放
70
+ const maxScore = 1.0;
71
+
72
+ // 绘制坐标轴
73
+ ctx.strokeStyle = '#E0E0E0';
74
+ ctx.lineWidth = 2;
75
+ ctx.beginPath();
76
+ ctx.moveTo(padding, padding);
77
+ ctx.lineTo(padding, height - padding);
78
+ ctx.lineTo(width - padding, height - padding);
79
+ ctx.stroke();
80
+
81
+ // 绘制Y轴刻度
82
+ ctx.fillStyle = '#666';
83
+ ctx.font = '12px Nunito';
84
+ ctx.textAlign = 'right';
85
+ for (let i = 0; i <= 4; i++) {
86
+ const y = height - padding - (chartHeight * i / 4);
87
+ const value = (maxScore * i / 4 * 100).toFixed(0);
88
+ ctx.fillText(value + '%', padding - 10, y + 4);
89
+
90
+ // 绘制网格线
91
+ ctx.strokeStyle = '#F0F0F0';
92
+ ctx.lineWidth = 1;
93
+ ctx.beginPath();
94
+ ctx.moveTo(padding, y);
95
+ ctx.lineTo(width - padding, y);
96
+ ctx.stroke();
97
+ }
98
+
99
+ // 绘制柱状图
100
+ displayData.forEach((session, index) => {
101
+ const barHeight = (session.focus_score / maxScore) * chartHeight;
102
+ const x = padding + index * barWidth + barWidth * 0.1;
103
+ const y = height - padding - barHeight;
104
+ const barActualWidth = barWidth * 0.8;
105
+
106
+ // 根据分数设置颜色 - 使用蓝色主题
107
+ const score = session.focus_score;
108
+ let color;
109
+ if (score >= 0.8) color = '#4A90E2';
110
+ else if (score >= 0.6) color = '#5DADE2';
111
+ else if (score >= 0.4) color = '#85C1E9';
112
+ else color = '#AED6F1';
113
+
114
+ ctx.fillStyle = color;
115
+ ctx.fillRect(x, y, barActualWidth, barHeight);
116
+
117
+ // 绘制边框
118
+ ctx.strokeStyle = color;
119
+ ctx.lineWidth = 1;
120
+ ctx.strokeRect(x, y, barActualWidth, barHeight);
121
+ });
122
+
123
+ // 绘制图例
124
+ ctx.textAlign = 'left';
125
+ ctx.font = 'bold 14px Nunito';
126
+ ctx.fillStyle = '#4A90E2';
127
+ ctx.fillText('Focus Score by Session', padding, 30);
128
+ };
129
+
130
+ // 初始加载
131
+ useEffect(() => {
132
+ loadSessions(filter);
133
+ }, [filter]);
134
+
135
+ // 处理筛选按钮点击
136
+ const handleFilterClick = (filterType) => {
137
+ setFilter(filterType);
138
+ };
139
+
140
+ // 查看详情
141
+ const handleViewDetails = (sessionId) => {
142
+ // 这里可以实现查看详情的功能,比如弹窗显示该会话的详细信息
143
+ alert(`View details for session ${sessionId}\n(Feature can be extended later)`);
144
+ };
145
+
146
+ return (
147
+ <main id="page-d" className="page">
148
+ <h1 className="page-title">My Records</h1>
149
+
150
+ <div className="records-controls" style={{ display: 'flex', justifyContent: 'center', gap: '10px', marginBottom: '30px' }}>
151
+ <button
152
+ id="filter-today"
153
+ onClick={() => handleFilterClick('today')}
154
+ style={{
155
+ padding: '10px 30px',
156
+ borderRadius: '8px',
157
+ border: filter === 'today' ? 'none' : '2px solid #4A90E2',
158
+ background: filter === 'today' ? '#4A90E2' : 'transparent',
159
+ color: filter === 'today' ? 'white' : '#4A90E2',
160
+ fontSize: '14px',
161
+ fontWeight: '500',
162
+ cursor: 'pointer',
163
+ transition: 'all 0.3s'
164
+ }}
165
+ >
166
+ Today
167
+ </button>
168
+ <button
169
+ id="filter-week"
170
+ onClick={() => handleFilterClick('week')}
171
+ style={{
172
+ padding: '10px 30px',
173
+ borderRadius: '8px',
174
+ border: filter === 'week' ? 'none' : '2px solid #4A90E2',
175
+ background: filter === 'week' ? '#4A90E2' : 'transparent',
176
+ color: filter === 'week' ? 'white' : '#4A90E2',
177
+ fontSize: '14px',
178
+ fontWeight: '500',
179
+ cursor: 'pointer',
180
+ transition: 'all 0.3s'
181
+ }}
182
+ >
183
+ This Week
184
+ </button>
185
+ <button
186
+ id="filter-month"
187
+ onClick={() => handleFilterClick('month')}
188
+ style={{
189
+ padding: '10px 30px',
190
+ borderRadius: '8px',
191
+ border: filter === 'month' ? 'none' : '2px solid #4A90E2',
192
+ background: filter === 'month' ? '#4A90E2' : 'transparent',
193
+ color: filter === 'month' ? 'white' : '#4A90E2',
194
+ fontSize: '14px',
195
+ fontWeight: '500',
196
+ cursor: 'pointer',
197
+ transition: 'all 0.3s'
198
+ }}
199
+ >
200
+ This Month
201
+ </button>
202
+ <button
203
+ id="filter-all"
204
+ onClick={() => handleFilterClick('all')}
205
+ style={{
206
+ padding: '10px 30px',
207
+ borderRadius: '8px',
208
+ border: filter === 'all' ? 'none' : '2px solid #4A90E2',
209
+ background: filter === 'all' ? '#4A90E2' : 'transparent',
210
+ color: filter === 'all' ? 'white' : '#4A90E2',
211
+ fontSize: '14px',
212
+ fontWeight: '500',
213
+ cursor: 'pointer',
214
+ transition: 'all 0.3s'
215
+ }}
216
+ >
217
+ All Time
218
+ </button>
219
+ </div>
220
+
221
+ <div className="chart-container" style={{
222
+ background: 'white',
223
+ padding: '20px',
224
+ borderRadius: '10px',
225
+ marginBottom: '30px',
226
+ boxShadow: '0 2px 8px rgba(0,0,0,0.1)'
227
+ }}>
228
+ <canvas ref={chartRef} id="focus-chart" style={{ width: '100%', height: '300px' }}></canvas>
229
+ </div>
230
+
231
+ <div className="sessions-list" style={{
232
+ background: 'white',
233
+ padding: '20px',
234
+ borderRadius: '10px',
235
+ boxShadow: '0 2px 8px rgba(0,0,0,0.1)'
236
+ }}>
237
+ <h2 style={{ color: '#333', marginBottom: '20px', fontSize: '18px', fontWeight: '600' }}>Recent Sessions</h2>
238
+ {loading ? (
239
+ <div style={{ textAlign: 'center', padding: '40px', color: '#999' }}>
240
+ Loading sessions...
241
+ </div>
242
+ ) : sessions.length === 0 ? (
243
+ <div style={{ textAlign: 'center', padding: '40px', color: '#999' }}>
244
+ No sessions found for this period.
245
+ </div>
246
+ ) : (
247
+ <table id="sessions-table" style={{ width: '100%', borderCollapse: 'collapse', borderRadius: '10px', overflow: 'hidden' }}>
248
+ <thead>
249
+ <tr style={{ background: '#4A90E2' }}>
250
+ <th style={{ padding: '15px', textAlign: 'left', color: 'white', fontWeight: '600', fontSize: '14px' }}>Date</th>
251
+ <th style={{ padding: '15px', textAlign: 'center', color: 'white', fontWeight: '600', fontSize: '14px' }}>Duration</th>
252
+ <th style={{ padding: '15px', textAlign: 'center', color: 'white', fontWeight: '600', fontSize: '14px' }}>Focus Score</th>
253
+ <th style={{ padding: '15px', textAlign: 'center', color: 'white', fontWeight: '600', fontSize: '14px' }}>Action</th>
254
+ </tr>
255
+ </thead>
256
+ <tbody id="sessions-tbody">
257
+ {sessions.map((session, index) => (
258
+ <tr key={session.id} style={{
259
+ background: index % 2 === 0 ? '#f8f9fa' : 'white',
260
+ borderBottom: '1px solid #e9ecef'
261
+ }}>
262
+ <td style={{ padding: '15px', color: '#333', fontSize: '13px' }}>{formatDate(session.start_time)}</td>
263
+ <td style={{ padding: '15px', textAlign: 'center', color: '#333', fontSize: '13px' }}>{formatDuration(session.duration_seconds)}</td>
264
+ <td style={{ padding: '15px', textAlign: 'center' }}>
265
+ <span
266
+ style={{
267
+ color:
268
+ session.focus_score >= 0.8
269
+ ? '#28a745'
270
+ : session.focus_score >= 0.6
271
+ ? '#ffc107'
272
+ : session.focus_score >= 0.4
273
+ ? '#fd7e14'
274
+ : '#dc3545',
275
+ fontWeight: '600',
276
+ fontSize: '13px'
277
+ }}
278
+ >
279
+ {(session.focus_score * 100).toFixed(1)}%
280
+ </span>
281
+ </td>
282
+ <td style={{ padding: '15px', textAlign: 'center' }}>
283
+ <button
284
+ onClick={() => handleViewDetails(session.id)}
285
+ style={{
286
+ padding: '6px 20px',
287
+ background: '#4A90E2',
288
+ border: 'none',
289
+ color: 'white',
290
+ borderRadius: '6px',
291
+ cursor: 'pointer',
292
+ fontSize: '12px',
293
+ fontWeight: '500',
294
+ transition: 'background 0.3s'
295
+ }}
296
+ onMouseOver={(e) => e.target.style.background = '#357ABD'}
297
+ onMouseOut={(e) => e.target.style.background = '#4A90E2'}
298
+ >
299
+ View
300
+ </button>
301
+ </td>
302
+ </tr>
303
+ ))}
304
+ </tbody>
305
+ </table>
306
+ )}
307
+ </div>
308
+ </main>
309
+ );
310
+ }
311
+
312
+ export default Records;
src/index.css ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ font-family: system-ui, Avenir, Helvetica, Arial, sans-serif;
3
+ line-height: 1.5;
4
+ font-weight: 400;
5
+
6
+ color-scheme: light dark;
7
+ color: rgba(255, 255, 255, 0.87);
8
+ background-color: #242424;
9
+
10
+ font-synthesis: none;
11
+ text-rendering: optimizeLegibility;
12
+ -webkit-font-smoothing: antialiased;
13
+ -moz-osx-font-smoothing: grayscale;
14
+ }
15
+
16
+ a {
17
+ font-weight: 500;
18
+ color: #646cff;
19
+ text-decoration: inherit;
20
+ }
21
+
22
+ a:hover {
23
+ color: #535bf2;
24
+ }
25
+
26
+ body {
27
+ margin: 0;
28
+ display: flex;
29
+ place-items: center;
30
+ min-width: 320px;
31
+ min-height: 100vh;
32
+ }
33
+
34
+ h1 {
35
+ font-size: 3.2em;
36
+ line-height: 1.1;
37
+ }
38
+
39
+ button {
40
+ border-radius: 8px;
41
+ border: 1px solid transparent;
42
+ padding: 0.6em 1.2em;
43
+ font-size: 1em;
44
+ font-weight: 500;
45
+ font-family: inherit;
46
+ background-color: #1a1a1a;
47
+ cursor: pointer;
48
+ transition: border-color 0.25s;
49
+ }
50
+
51
+ button:hover {
52
+ border-color: #646cff;
53
+ }
54
+
55
+ button:focus,
56
+ button:focus-visible {
57
+ outline: 4px auto -webkit-focus-ring-color;
58
+ }
59
+
60
+ @media (prefers-color-scheme: light) {
61
+ :root {
62
+ color: #213547;
63
+ background-color: #ffffff;
64
+ }
65
+
66
+ a:hover {
67
+ color: #747bff;
68
+ }
69
+
70
+ button {
71
+ background-color: #f9f9f9;
72
+ }
73
+ }
src/main.jsx ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import { StrictMode } from 'react'
2
+ import { createRoot } from 'react-dom/client'
3
+ import './index.css'
4
+ import App from './App.jsx'
5
+
6
+ createRoot(document.getElementById('root')).render(
7
+ <StrictMode>
8
+ <App />
9
+ </StrictMode>,
10
+ )
src/utils/VideoManager.js ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/utils/VideoManager.js
2
+
3
+ export class VideoManager {
4
+ constructor(callbacks) {
5
+ // callbacks 用于通知 React 组件更新界面
6
+ // 例如: onStatusUpdate, onSessionStart, onSessionEnd
7
+ this.callbacks = callbacks || {};
8
+
9
+ this.videoElement = null; // 显示后端处理后的视频
10
+ this.stream = null; // 本地摄像头流
11
+ this.pc = null;
12
+ this.dataChannel = null;
13
+
14
+ this.isStreaming = false;
15
+ this.sessionId = null;
16
+ this.frameRate = 30;
17
+
18
+ // 状态平滑处理
19
+ this.currentStatus = false;
20
+ this.statusBuffer = [];
21
+ this.bufferSize = 5;
22
+
23
+ // 检测数据
24
+ this.latestDetectionData = null;
25
+ this.lastConfidence = 0;
26
+ this.detectionHoldMs = 30;
27
+
28
+ // 通知系统
29
+ this.notificationEnabled = true;
30
+ this.notificationThreshold = 30; // 默认30秒
31
+ this.unfocusedStartTime = null;
32
+ this.lastNotificationTime = null;
33
+ this.notificationCooldown = 60000; // 通知冷却时间60秒
34
+ }
35
+
36
+ // 初始化:获取摄像头流,并记录展示视频的元素
37
+ async initCamera(videoRef) {
38
+ try {
39
+ this.stream = await navigator.mediaDevices.getUserMedia({
40
+ video: {
41
+ width: { ideal: 640 },
42
+ height: { ideal: 480 },
43
+ facingMode: 'user'
44
+ },
45
+ audio: false
46
+ });
47
+
48
+ this.videoElement = videoRef;
49
+ return true;
50
+ } catch (error) {
51
+ console.error('Camera init error:', error);
52
+ throw error;
53
+ }
54
+ }
55
+
56
+ async startStreaming() {
57
+ if (!this.stream) {
58
+ console.error('❌ No stream available');
59
+ throw new Error('Camera stream not initialized');
60
+ }
61
+ this.isStreaming = true;
62
+
63
+ console.log('📹 Starting streaming...');
64
+
65
+ // 请求通知权限
66
+ await this.requestNotificationPermission();
67
+ // 加载通知设置
68
+ await this.loadNotificationSettings();
69
+
70
+ this.pc = new RTCPeerConnection({
71
+ iceServers: [
72
+ { urls: 'stun:stun.l.google.com:19302' },
73
+ { urls: 'stun:stun1.l.google.com:19302' },
74
+ { urls: 'stun:stun2.l.google.com:19302' },
75
+ { urls: 'stun:stun3.l.google.com:19302' },
76
+ { urls: 'stun:stun4.l.google.com:19302' }
77
+ ],
78
+ iceCandidatePoolSize: 10
79
+ });
80
+
81
+ // 添加连接状态监控
82
+ this.pc.onconnectionstatechange = () => {
83
+ console.log('🔗 Connection state:', this.pc.connectionState);
84
+ };
85
+
86
+ this.pc.oniceconnectionstatechange = () => {
87
+ console.log('🧊 ICE connection state:', this.pc.iceConnectionState);
88
+ };
89
+
90
+ this.pc.onicegatheringstatechange = () => {
91
+ console.log('📡 ICE gathering state:', this.pc.iceGatheringState);
92
+ };
93
+
94
+ // DataChannel for status updates
95
+ this.dataChannel = this.pc.createDataChannel('status');
96
+ this.dataChannel.onmessage = (event) => {
97
+ try {
98
+ const data = JSON.parse(event.data);
99
+ this.handleServerMessage(data);
100
+ } catch (e) {
101
+ console.error('Failed to parse data channel message:', e);
102
+ }
103
+ };
104
+
105
+ this.pc.ontrack = (event) => {
106
+ const stream = event.streams[0];
107
+ if (this.videoElement) {
108
+ this.videoElement.srcObject = stream;
109
+ this.videoElement.autoplay = true;
110
+ this.videoElement.playsInline = true;
111
+ this.videoElement.play().catch(() => {});
112
+ }
113
+ };
114
+
115
+ // Add local camera tracks
116
+ this.stream.getTracks().forEach((track) => {
117
+ this.pc.addTrack(track, this.stream);
118
+ });
119
+
120
+ const offer = await this.pc.createOffer();
121
+ await this.pc.setLocalDescription(offer);
122
+
123
+ // Wait for ICE gathering to complete so SDP includes candidates
124
+ await new Promise((resolve) => {
125
+ if (this.pc.iceGatheringState === 'complete') {
126
+ resolve();
127
+ return;
128
+ }
129
+ const onIce = () => {
130
+ if (this.pc.iceGatheringState === 'complete') {
131
+ this.pc.removeEventListener('icegatheringstatechange', onIce);
132
+ resolve();
133
+ }
134
+ };
135
+ this.pc.addEventListener('icegatheringstatechange', onIce);
136
+ });
137
+
138
+ console.log('📤 Sending offer to server...');
139
+ const response = await fetch('/api/webrtc/offer', {
140
+ method: 'POST',
141
+ headers: { 'Content-Type': 'application/json' },
142
+ body: JSON.stringify({
143
+ sdp: this.pc.localDescription.sdp,
144
+ type: this.pc.localDescription.type
145
+ })
146
+ });
147
+
148
+ if (!response.ok) {
149
+ const errorText = await response.text();
150
+ console.error('❌ Server error:', errorText);
151
+ throw new Error(`Server returned ${response.status}: ${errorText}`);
152
+ }
153
+
154
+ const answer = await response.json();
155
+ console.log('✅ Received answer from server, session_id:', answer.session_id);
156
+
157
+ await this.pc.setRemoteDescription(answer);
158
+ console.log('✅ Remote description set successfully');
159
+
160
+ this.sessionId = answer.session_id;
161
+ if (this.callbacks.onSessionStart) {
162
+ this.callbacks.onSessionStart(this.sessionId);
163
+ }
164
+
165
+ }
166
+
167
+ async requestNotificationPermission() {
168
+ if ('Notification' in window && Notification.permission === 'default') {
169
+ try {
170
+ await Notification.requestPermission();
171
+ } catch (error) {
172
+ console.error('Failed to request notification permission:', error);
173
+ }
174
+ }
175
+ }
176
+
177
+ async loadNotificationSettings() {
178
+ try {
179
+ const response = await fetch('/api/settings');
180
+ const settings = await response.json();
181
+ if (settings) {
182
+ this.notificationEnabled = settings.notification_enabled ?? true;
183
+ this.notificationThreshold = settings.notification_threshold ?? 30;
184
+ }
185
+ } catch (error) {
186
+ console.error('Failed to load notification settings:', error);
187
+ }
188
+ }
189
+
190
+ sendNotification(title, message) {
191
+ if (!this.notificationEnabled) return;
192
+ if ('Notification' in window && Notification.permission === 'granted') {
193
+ try {
194
+ const notification = new Notification(title, {
195
+ body: message,
196
+ icon: '/vite.svg',
197
+ badge: '/vite.svg',
198
+ tag: 'focus-guard-distraction',
199
+ requireInteraction: false
200
+ });
201
+
202
+ // 3秒后自动关闭
203
+ setTimeout(() => notification.close(), 3000);
204
+ } catch (error) {
205
+ console.error('Failed to send notification:', error);
206
+ }
207
+ }
208
+ }
209
+
210
+ handleServerMessage(data) {
211
+ switch (data.type) {
212
+ case 'detection':
213
+ this.updateStatus(data.focused);
214
+ this.latestDetectionData = {
215
+ detections: data.detections || [],
216
+ confidence: data.confidence || 0,
217
+ focused: data.focused,
218
+ timestamp: performance.now()
219
+ };
220
+ this.lastConfidence = data.confidence || 0;
221
+
222
+ if (this.callbacks.onStatusUpdate) {
223
+ this.callbacks.onStatusUpdate(this.currentStatus);
224
+ }
225
+ break;
226
+ default:
227
+ break;
228
+ }
229
+ }
230
+
231
+ updateStatus(newFocused) {
232
+ this.statusBuffer.push(newFocused);
233
+ if (this.statusBuffer.length > this.bufferSize) {
234
+ this.statusBuffer.shift();
235
+ }
236
+
237
+ if (this.statusBuffer.length < this.bufferSize) return false;
238
+
239
+ const focusedCount = this.statusBuffer.filter(f => f).length;
240
+ const focusedRatio = focusedCount / this.statusBuffer.length;
241
+
242
+ const previousStatus = this.currentStatus;
243
+
244
+ if (focusedRatio >= 0.75) {
245
+ this.currentStatus = true;
246
+ } else if (focusedRatio <= 0.25) {
247
+ this.currentStatus = false;
248
+ }
249
+
250
+ // 通知逻辑
251
+ this.handleNotificationLogic(previousStatus, this.currentStatus);
252
+ }
253
+
254
+ handleNotificationLogic(previousStatus, currentStatus) {
255
+ const now = Date.now();
256
+
257
+ // 如果从专注变为不专注,记录开始时间
258
+ if (previousStatus && !currentStatus) {
259
+ this.unfocusedStartTime = now;
260
+ }
261
+
262
+ // 如果从不专注变为专注,清除计时
263
+ if (!previousStatus && currentStatus) {
264
+ this.unfocusedStartTime = null;
265
+ }
266
+
267
+ // 如果持续不专注
268
+ if (!currentStatus && this.unfocusedStartTime) {
269
+ const unfocusedDuration = (now - this.unfocusedStartTime) / 1000; // 秒
270
+
271
+ // 检查是否超过阈值且不在冷却期
272
+ if (unfocusedDuration >= this.notificationThreshold) {
273
+ const canSendNotification = !this.lastNotificationTime ||
274
+ (now - this.lastNotificationTime) >= this.notificationCooldown;
275
+
276
+ if (canSendNotification) {
277
+ this.sendNotification(
278
+ '⚠️ Focus Alert',
279
+ `You've been distracted for ${Math.floor(unfocusedDuration)} seconds. Get back to work!`
280
+ );
281
+ this.lastNotificationTime = now;
282
+ }
283
+ }
284
+ }
285
+ }
286
+
287
+ async stopStreaming() {
288
+ this.isStreaming = false;
289
+
290
+ try {
291
+ if (document.pictureInPictureElement) {
292
+ await document.exitPictureInPicture();
293
+ }
294
+ if (this.videoElement && typeof this.videoElement.webkitSetPresentationMode === 'function') {
295
+ if (this.videoElement.webkitPresentationMode === 'picture-in-picture') {
296
+ this.videoElement.webkitSetPresentationMode('inline');
297
+ }
298
+ }
299
+ } catch (e) {
300
+ // ignore PiP exit errors
301
+ }
302
+
303
+ if (this.pc) {
304
+ try {
305
+ this.pc.getSenders().forEach(sender => sender.track && sender.track.stop());
306
+ this.pc.close();
307
+ } catch (e) {
308
+ console.error('Failed to close RTCPeerConnection:', e);
309
+ }
310
+ this.pc = null;
311
+ }
312
+
313
+ if (this.stream) {
314
+ this.stream.getTracks().forEach(track => track.stop());
315
+ this.stream = null;
316
+ }
317
+
318
+ if (this.videoElement) {
319
+ this.videoElement.srcObject = null;
320
+ }
321
+
322
+ if (this.sessionId) {
323
+ try {
324
+ const response = await fetch('/api/sessions/end', {
325
+ method: 'POST',
326
+ headers: { 'Content-Type': 'application/json' },
327
+ body: JSON.stringify({ session_id: this.sessionId })
328
+ });
329
+ const summary = await response.json();
330
+ if (this.callbacks.onSessionEnd) {
331
+ this.callbacks.onSessionEnd(summary);
332
+ }
333
+ } catch (e) {
334
+ console.error('Failed to end session:', e);
335
+ }
336
+ }
337
+
338
+ // 清理通知状态
339
+ this.unfocusedStartTime = null;
340
+ this.lastNotificationTime = null;
341
+ this.sessionId = null;
342
+ }
343
+
344
+ setFrameRate(rate) {
345
+ this.frameRate = Math.max(1, Math.min(60, rate));
346
+ if (this.stream) {
347
+ const videoTrack = this.stream.getVideoTracks()[0];
348
+ if (videoTrack && videoTrack.applyConstraints) {
349
+ videoTrack.applyConstraints({ frameRate: { ideal: this.frameRate, max: this.frameRate } }).catch(() => {});
350
+ }
351
+ }
352
+ }
353
+ }
src/utils/VideoManagerLocal.js ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/utils/VideoManagerLocal.js
2
+ // 本地视频处理版本 - 使用 WebSocket + Canvas,不依赖 WebRTC
3
+
4
+ export class VideoManagerLocal {
5
+ constructor(callbacks) {
6
+ this.callbacks = callbacks || {};
7
+
8
+ this.localVideoElement = null; // 显示本地摄像头
9
+ this.displayVideoElement = null; // 显示处理后的视频
10
+ this.canvas = null;
11
+ this.stream = null;
12
+ this.ws = null;
13
+
14
+ this.isStreaming = false;
15
+ this.sessionId = null;
16
+ this.sessionStartTime = null;
17
+ this.frameRate = 15; // 降低帧率以减少网络负载
18
+ this.captureInterval = null;
19
+
20
+ // 状态平滑处理
21
+ this.currentStatus = false;
22
+ this.statusBuffer = [];
23
+ this.bufferSize = 3;
24
+
25
+ // 检测数据
26
+ this.latestDetectionData = null;
27
+ this.lastConfidence = 0;
28
+
29
+ // 通知系统
30
+ this.notificationEnabled = true;
31
+ this.notificationThreshold = 30;
32
+ this.unfocusedStartTime = null;
33
+ this.lastNotificationTime = null;
34
+ this.notificationCooldown = 60000;
35
+
36
+ // 性能统计
37
+ this.stats = {
38
+ framesSent: 0,
39
+ framesProcessed: 0,
40
+ avgLatency: 0,
41
+ lastLatencies: []
42
+ };
43
+ }
44
+
45
+ // 初始化摄像头
46
+ async initCamera(localVideoRef, displayCanvasRef) {
47
+ try {
48
+ console.log('Initializing local camera...');
49
+
50
+ this.stream = await navigator.mediaDevices.getUserMedia({
51
+ video: {
52
+ width: { ideal: 640 },
53
+ height: { ideal: 480 },
54
+ facingMode: 'user'
55
+ },
56
+ audio: false
57
+ });
58
+
59
+ this.localVideoElement = localVideoRef;
60
+ this.displayCanvas = displayCanvasRef;
61
+
62
+ // 显示本地视频流
63
+ if (this.localVideoElement) {
64
+ this.localVideoElement.srcObject = this.stream;
65
+ this.localVideoElement.play();
66
+ }
67
+
68
+ // 创建用于截图的 canvas
69
+ this.canvas = document.createElement('canvas');
70
+ this.canvas.width = 640;
71
+ this.canvas.height = 480;
72
+
73
+ console.log('Local camera initialized');
74
+ return true;
75
+ } catch (error) {
76
+ console.error('Camera init error:', error);
77
+ throw error;
78
+ }
79
+ }
80
+
81
+ // 开始流式处理
82
+ async startStreaming() {
83
+ if (!this.stream) {
84
+ throw new Error('Camera not initialized');
85
+ }
86
+
87
+ if (this.isStreaming) {
88
+ console.warn('Already streaming');
89
+ return;
90
+ }
91
+
92
+ console.log('Starting WebSocket streaming...');
93
+ this.isStreaming = true;
94
+
95
+ // 请求通知权限
96
+ await this.requestNotificationPermission();
97
+ await this.loadNotificationSettings();
98
+
99
+ // 建立 WebSocket 连接
100
+ await this.connectWebSocket();
101
+
102
+ // 开始定期截图并发送
103
+ this.startCapture();
104
+
105
+ console.log('Streaming started');
106
+ }
107
+
108
+ // 建立 WebSocket 连接
109
+ async connectWebSocket() {
110
+ return new Promise((resolve, reject) => {
111
+ const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
112
+ const wsUrl = `${protocol}//${window.location.host}/ws/video`;
113
+
114
+ console.log('Connecting to WebSocket:', wsUrl);
115
+
116
+ this.ws = new WebSocket(wsUrl);
117
+
118
+ this.ws.onopen = () => {
119
+ console.log('WebSocket connected');
120
+
121
+ // 发送开始会话请求
122
+ this.ws.send(JSON.stringify({ type: 'start_session' }));
123
+ resolve();
124
+ };
125
+
126
+ this.ws.onmessage = (event) => {
127
+ try {
128
+ const data = JSON.parse(event.data);
129
+ this.handleServerMessage(data);
130
+ } catch (e) {
131
+ console.error('Failed to parse message:', e);
132
+ }
133
+ };
134
+
135
+ this.ws.onerror = (error) => {
136
+ console.error('WebSocket error:', error);
137
+ reject(error);
138
+ };
139
+
140
+ this.ws.onclose = () => {
141
+ console.log('WebSocket disconnected');
142
+ if (this.isStreaming) {
143
+ console.log('Attempting to reconnect...');
144
+ setTimeout(() => this.connectWebSocket(), 2000);
145
+ }
146
+ };
147
+ });
148
+ }
149
+
150
+ // 开始截图并发送
151
+ startCapture() {
152
+ const interval = 1000 / this.frameRate; // 转换为毫秒
153
+
154
+ this.captureInterval = setInterval(() => {
155
+ if (!this.isStreaming || !this.ws || this.ws.readyState !== WebSocket.OPEN) {
156
+ return;
157
+ }
158
+
159
+ try {
160
+ // 从视频流截取一帧
161
+ const ctx = this.canvas.getContext('2d');
162
+ ctx.drawImage(this.localVideoElement, 0, 0, this.canvas.width, this.canvas.height);
163
+
164
+ // 转换为 JPEG base64(压缩质量 0.8)
165
+ const imageData = this.canvas.toDataURL('image/jpeg', 0.8);
166
+ const base64Data = imageData.split(',')[1];
167
+
168
+ // 发送到服务器
169
+ const timestamp = Date.now();
170
+ this.ws.send(JSON.stringify({
171
+ type: 'frame',
172
+ image: base64Data,
173
+ timestamp: timestamp
174
+ }));
175
+
176
+ this.stats.framesSent++;
177
+
178
+ } catch (error) {
179
+ console.error('Capture error:', error);
180
+ }
181
+ }, interval);
182
+
183
+ console.log(`Capturing at ${this.frameRate} FPS`);
184
+ }
185
+
186
+ // 处理服务器消息
187
+ handleServerMessage(data) {
188
+ switch (data.type) {
189
+ case 'session_started':
190
+ this.sessionId = data.session_id;
191
+ this.sessionStartTime = Date.now();
192
+ console.log('Session started:', this.sessionId);
193
+ if (this.callbacks.onSessionStart) {
194
+ this.callbacks.onSessionStart(this.sessionId);
195
+ }
196
+ break;
197
+
198
+ case 'detection':
199
+ this.stats.framesProcessed++;
200
+
201
+ // 计算延迟
202
+ if (data.timestamp) {
203
+ const latency = Date.now() - data.timestamp;
204
+ this.stats.lastLatencies.push(latency);
205
+ if (this.stats.lastLatencies.length > 10) {
206
+ this.stats.lastLatencies.shift();
207
+ }
208
+ this.stats.avgLatency =
209
+ this.stats.lastLatencies.reduce((a, b) => a + b, 0) /
210
+ this.stats.lastLatencies.length;
211
+ }
212
+
213
+ // 更新状态
214
+ this.updateStatus(data.focused);
215
+
216
+ this.latestDetectionData = {
217
+ detections: data.detections || [],
218
+ confidence: data.confidence || 0,
219
+ focused: data.focused,
220
+ timestamp: performance.now()
221
+ };
222
+
223
+ this.lastConfidence = data.confidence || 0;
224
+
225
+ if (this.callbacks.onStatusUpdate) {
226
+ this.callbacks.onStatusUpdate(this.currentStatus);
227
+ }
228
+
229
+ // 在 display canvas 上绘制结果
230
+ this.drawDetectionResult(data);
231
+ break;
232
+
233
+ case 'session_ended':
234
+ console.log('Received session_ended message');
235
+ console.log('Session summary:', data.summary);
236
+ if (this.callbacks.onSessionEnd) {
237
+ console.log('Calling onSessionEnd callback');
238
+ this.callbacks.onSessionEnd(data.summary);
239
+ } else {
240
+ console.warn('No onSessionEnd callback registered');
241
+ }
242
+ this.sessionId = null;
243
+ this.sessionStartTime = null;
244
+ break;
245
+
246
+ case 'error':
247
+ console.error('Server error:', data.message);
248
+ break;
249
+
250
+ default:
251
+ console.log('Unknown message type:', data.type);
252
+ }
253
+ }
254
+
255
+ // 在 canvas 上绘制检测结果
256
+ drawDetectionResult(data) {
257
+ if (!this.displayCanvas) return;
258
+
259
+ const ctx = this.displayCanvas.getContext('2d');
260
+
261
+ // 绘制当前帧
262
+ ctx.drawImage(this.localVideoElement, 0, 0, this.displayCanvas.width, this.displayCanvas.height);
263
+
264
+ // 绘制检测框
265
+ if (data.detections && data.detections.length > 0) {
266
+ data.detections.forEach(det => {
267
+ const [x1, y1, x2, y2] = det.bbox;
268
+ const color = data.focused ? '#00FF00' : '#FF0000';
269
+
270
+ // 绘制边框
271
+ ctx.strokeStyle = color;
272
+ ctx.lineWidth = 3;
273
+ ctx.strokeRect(x1, y1, x2 - x1, y2 - y1);
274
+
275
+ // 绘制标签
276
+ const label = `${det.class_name || 'person'} ${(det.confidence * 100).toFixed(1)}%`;
277
+ ctx.fillStyle = color;
278
+ ctx.font = '16px Arial';
279
+ ctx.fillText(label, x1, Math.max(20, y1 - 5));
280
+ });
281
+ }
282
+
283
+ // 绘制状态文字
284
+ const statusText = data.focused ? 'FOCUSED' : 'NOT FOCUSED';
285
+ const color = data.focused ? '#00FF00' : '#FF0000';
286
+ ctx.fillStyle = color;
287
+ ctx.font = 'bold 24px Arial';
288
+ ctx.fillText(statusText, 10, 30);
289
+
290
+ ctx.font = '16px Arial';
291
+ ctx.fillText(`Confidence: ${(data.confidence * 100).toFixed(1)}%`, 10, 55);
292
+
293
+ // 显示性能统计
294
+ ctx.font = '12px Arial';
295
+ ctx.fillStyle = '#FFFFFF';
296
+ ctx.fillText(`FPS: ${this.frameRate} | Latency: ${this.stats.avgLatency.toFixed(0)}ms`, 10, this.displayCanvas.height - 10);
297
+ }
298
+
299
+ updateStatus(newFocused) {
300
+ this.statusBuffer.push(newFocused);
301
+ if (this.statusBuffer.length > this.bufferSize) {
302
+ this.statusBuffer.shift();
303
+ }
304
+
305
+ if (this.statusBuffer.length < this.bufferSize) return false;
306
+
307
+ const focusedCount = this.statusBuffer.filter(f => f).length;
308
+ const focusedRatio = focusedCount / this.statusBuffer.length;
309
+
310
+ const previousStatus = this.currentStatus;
311
+
312
+ if (focusedRatio >= 0.75) {
313
+ this.currentStatus = true;
314
+ } else if (focusedRatio <= 0.25) {
315
+ this.currentStatus = false;
316
+ }
317
+
318
+ this.handleNotificationLogic(previousStatus, this.currentStatus);
319
+ }
320
+
321
+ handleNotificationLogic(previousStatus, currentStatus) {
322
+ const now = Date.now();
323
+
324
+ if (previousStatus && !currentStatus) {
325
+ this.unfocusedStartTime = now;
326
+ }
327
+
328
+ if (!previousStatus && currentStatus) {
329
+ this.unfocusedStartTime = null;
330
+ }
331
+
332
+ if (!currentStatus && this.unfocusedStartTime) {
333
+ const unfocusedDuration = (now - this.unfocusedStartTime) / 1000;
334
+
335
+ if (unfocusedDuration >= this.notificationThreshold) {
336
+ const canSendNotification = !this.lastNotificationTime ||
337
+ (now - this.lastNotificationTime) >= this.notificationCooldown;
338
+
339
+ if (canSendNotification) {
340
+ this.sendNotification(
341
+ 'Focus Alert',
342
+ `You've been distracted for ${Math.floor(unfocusedDuration)} seconds. Get back to work!`
343
+ );
344
+ this.lastNotificationTime = now;
345
+ }
346
+ }
347
+ }
348
+ }
349
+
350
+ async requestNotificationPermission() {
351
+ if ('Notification' in window && Notification.permission === 'default') {
352
+ try {
353
+ await Notification.requestPermission();
354
+ } catch (error) {
355
+ console.error('Failed to request notification permission:', error);
356
+ }
357
+ }
358
+ }
359
+
360
+ async loadNotificationSettings() {
361
+ try {
362
+ const response = await fetch('/api/settings');
363
+ const settings = await response.json();
364
+ if (settings) {
365
+ this.notificationEnabled = settings.notification_enabled ?? true;
366
+ this.notificationThreshold = settings.notification_threshold ?? 30;
367
+ }
368
+ } catch (error) {
369
+ console.error('Failed to load notification settings:', error);
370
+ }
371
+ }
372
+
373
+ sendNotification(title, message) {
374
+ if (!this.notificationEnabled) return;
375
+ if ('Notification' in window && Notification.permission === 'granted') {
376
+ try {
377
+ const notification = new Notification(title, {
378
+ body: message,
379
+ icon: '/vite.svg',
380
+ badge: '/vite.svg',
381
+ tag: 'focus-guard-distraction',
382
+ requireInteraction: false
383
+ });
384
+ setTimeout(() => notification.close(), 3000);
385
+ } catch (error) {
386
+ console.error('Failed to send notification:', error);
387
+ }
388
+ }
389
+ }
390
+
391
+ async stopStreaming() {
392
+ console.log('Stopping streaming...');
393
+
394
+ this.isStreaming = false;
395
+
396
+ // 停止截图
397
+ if (this.captureInterval) {
398
+ clearInterval(this.captureInterval);
399
+ this.captureInterval = null;
400
+ }
401
+
402
+ // 发送结束会话请求并等待响应
403
+ if (this.ws && this.ws.readyState === WebSocket.OPEN && this.sessionId) {
404
+ const sessionId = this.sessionId;
405
+
406
+ // 等待 session_ended 消息
407
+ const waitForSessionEnd = new Promise((resolve) => {
408
+ const originalHandler = this.ws.onmessage;
409
+ const timeout = setTimeout(() => {
410
+ this.ws.onmessage = originalHandler;
411
+ console.log('Session end timeout, proceeding anyway');
412
+ resolve();
413
+ }, 2000);
414
+
415
+ this.ws.onmessage = (event) => {
416
+ try {
417
+ const data = JSON.parse(event.data);
418
+ if (data.type === 'session_ended') {
419
+ clearTimeout(timeout);
420
+ this.handleServerMessage(data);
421
+ this.ws.onmessage = originalHandler;
422
+ resolve();
423
+ } else {
424
+ // 仍然处理其他消息
425
+ this.handleServerMessage(data);
426
+ }
427
+ } catch (e) {
428
+ console.error('Failed to parse message:', e);
429
+ }
430
+ };
431
+ });
432
+
433
+ console.log('Sending end_session request for session:', sessionId);
434
+ this.ws.send(JSON.stringify({
435
+ type: 'end_session',
436
+ session_id: sessionId
437
+ }));
438
+
439
+ // 等待响应或超时
440
+ await waitForSessionEnd;
441
+ }
442
+
443
+ // 延迟关闭 WebSocket 确保消息发送完成
444
+ await new Promise(resolve => setTimeout(resolve, 200));
445
+
446
+ // 关闭 WebSocket
447
+ if (this.ws) {
448
+ this.ws.close();
449
+ this.ws = null;
450
+ }
451
+
452
+ // 停止摄像头
453
+ if (this.stream) {
454
+ this.stream.getTracks().forEach(track => track.stop());
455
+ this.stream = null;
456
+ }
457
+
458
+ // 清空视频
459
+ if (this.localVideoElement) {
460
+ this.localVideoElement.srcObject = null;
461
+ }
462
+
463
+ // 清空 canvas
464
+ if (this.displayCanvas) {
465
+ const ctx = this.displayCanvas.getContext('2d');
466
+ ctx.clearRect(0, 0, this.displayCanvas.width, this.displayCanvas.height);
467
+ }
468
+
469
+ // 清理状态
470
+ this.unfocusedStartTime = null;
471
+ this.lastNotificationTime = null;
472
+
473
+ console.log('Streaming stopped');
474
+ console.log('Stats:', this.stats);
475
+ }
476
+
477
+ setFrameRate(rate) {
478
+ this.frameRate = Math.max(1, Math.min(30, rate));
479
+ console.log(`Frame rate set to ${this.frameRate} FPS`);
480
+
481
+ // 重启截图(如果正在运行)
482
+ if (this.isStreaming && this.captureInterval) {
483
+ clearInterval(this.captureInterval);
484
+ this.startCapture();
485
+ }
486
+ }
487
+
488
+ getStats() {
489
+ return {
490
+ ...this.stats,
491
+ isStreaming: this.isStreaming,
492
+ sessionId: this.sessionId,
493
+ currentStatus: this.currentStatus,
494
+ lastConfidence: this.lastConfidence
495
+ };
496
+ }
497
+ }
ui/live_demo.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import time
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from mediapipe.tasks.python.vision import FaceLandmarksConnections
9
+
10
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11
+ if _PROJECT_ROOT not in sys.path:
12
+ sys.path.insert(0, _PROJECT_ROOT)
13
+
14
+ from ui.pipeline import (
15
+ FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, GRUPipeline,
16
+ _load_gru_artifacts, _latest_model_artifacts,
17
+ )
18
+ from models.face_mesh import FaceMeshDetector
19
+
20
+ FONT = cv2.FONT_HERSHEY_SIMPLEX
21
+ CYAN = (255, 255, 0)
22
+ GREEN = (0, 255, 0)
23
+ MAGENTA = (255, 0, 255)
24
+ ORANGE = (0, 165, 255)
25
+ RED = (0, 0, 255)
26
+ WHITE = (255, 255, 255)
27
+ YELLOW = (0, 255, 255)
28
+ LIGHT_GREEN = (144, 238, 144)
29
+
30
+ _TESSELATION = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
31
+ _CONTOURS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
32
+ _LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
33
+ _RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
34
+ _NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
35
+ _LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
36
+ _LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
37
+ _LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
38
+ _RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
39
+
40
+ MESH_FULL = 0
41
+ MESH_CONTOURS = 1
42
+ MESH_OFF = 2
43
+ _MESH_NAMES = ["FULL MESH", "CONTOURS", "MESH OFF"]
44
+
45
+ MODE_GEO = 0
46
+ MODE_MLP = 1
47
+ MODE_GRU = 2
48
+ MODE_HYBRID = 3
49
+ _MODE_NAMES = ["GEOMETRIC", "MLP", "GRU", "HYBRID"]
50
+ _MODE_KEYS = {ord("1"): MODE_GEO, ord("2"): MODE_MLP, ord("3"): MODE_GRU, ord("4"): MODE_HYBRID}
51
+
52
+
53
+ def _lm_to_px(landmarks, idx, w, h):
54
+ return (int(landmarks[idx, 0] * w), int(landmarks[idx, 1] * h))
55
+
56
+
57
+ def draw_tessellation(frame, landmarks, w, h):
58
+ overlay = frame.copy()
59
+ for conn in _TESSELATION:
60
+ pt1 = _lm_to_px(landmarks, conn[0], w, h)
61
+ pt2 = _lm_to_px(landmarks, conn[1], w, h)
62
+ cv2.line(overlay, pt1, pt2, (200, 200, 200), 1, cv2.LINE_AA)
63
+ cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
64
+
65
+
66
+ def draw_contours(frame, landmarks, w, h):
67
+ for conn in _CONTOURS:
68
+ pt1 = _lm_to_px(landmarks, conn[0], w, h)
69
+ pt2 = _lm_to_px(landmarks, conn[1], w, h)
70
+ cv2.line(frame, pt1, pt2, CYAN, 1, cv2.LINE_AA)
71
+ for indices in [_LEFT_EYEBROW, _RIGHT_EYEBROW]:
72
+ for i in range(len(indices) - 1):
73
+ pt1 = _lm_to_px(landmarks, indices[i], w, h)
74
+ pt2 = _lm_to_px(landmarks, indices[i + 1], w, h)
75
+ cv2.line(frame, pt1, pt2, LIGHT_GREEN, 2, cv2.LINE_AA)
76
+ for i in range(len(_NOSE_BRIDGE) - 1):
77
+ pt1 = _lm_to_px(landmarks, _NOSE_BRIDGE[i], w, h)
78
+ pt2 = _lm_to_px(landmarks, _NOSE_BRIDGE[i + 1], w, h)
79
+ cv2.line(frame, pt1, pt2, ORANGE, 1, cv2.LINE_AA)
80
+ for i in range(len(_LIPS_OUTER) - 1):
81
+ pt1 = _lm_to_px(landmarks, _LIPS_OUTER[i], w, h)
82
+ pt2 = _lm_to_px(landmarks, _LIPS_OUTER[i + 1], w, h)
83
+ cv2.line(frame, pt1, pt2, MAGENTA, 1, cv2.LINE_AA)
84
+ for i in range(len(_LIPS_INNER) - 1):
85
+ pt1 = _lm_to_px(landmarks, _LIPS_INNER[i], w, h)
86
+ pt2 = _lm_to_px(landmarks, _LIPS_INNER[i + 1], w, h)
87
+ cv2.line(frame, pt1, pt2, (200, 0, 200), 1, cv2.LINE_AA)
88
+
89
+
90
+ def draw_eyes_and_irises(frame, landmarks, w, h):
91
+ left_pts = np.array(
92
+ [_lm_to_px(landmarks, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES],
93
+ dtype=np.int32,
94
+ )
95
+ cv2.polylines(frame, [left_pts], True, GREEN, 2, cv2.LINE_AA)
96
+ right_pts = np.array(
97
+ [_lm_to_px(landmarks, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES],
98
+ dtype=np.int32,
99
+ )
100
+ cv2.polylines(frame, [right_pts], True, GREEN, 2, cv2.LINE_AA)
101
+ for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
102
+ for idx in indices:
103
+ pt = _lm_to_px(landmarks, idx, w, h)
104
+ cv2.circle(frame, pt, 3, YELLOW, -1, cv2.LINE_AA)
105
+ for iris_indices, eye_inner, eye_outer in [
106
+ (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
107
+ (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
108
+ ]:
109
+ iris_pts = np.array(
110
+ [_lm_to_px(landmarks, i, w, h) for i in iris_indices],
111
+ dtype=np.int32,
112
+ )
113
+ center = iris_pts[0]
114
+ if len(iris_pts) >= 5:
115
+ radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
116
+ radius = max(int(np.mean(radii)), 2)
117
+ cv2.circle(frame, tuple(center), radius, MAGENTA, 2, cv2.LINE_AA)
118
+ cv2.circle(frame, tuple(center), 2, WHITE, -1, cv2.LINE_AA)
119
+ eye_center_x = (landmarks[eye_inner, 0] + landmarks[eye_outer, 0]) / 2.0
120
+ eye_center_y = (landmarks[eye_inner, 1] + landmarks[eye_outer, 1]) / 2.0
121
+ eye_center = (int(eye_center_x * w), int(eye_center_y * h))
122
+ dx = center[0] - eye_center[0]
123
+ dy = center[1] - eye_center[1]
124
+ gaze_end = (int(center[0] + dx * 3), int(center[1] + dy * 3))
125
+ cv2.line(frame, tuple(center), gaze_end, RED, 1, cv2.LINE_AA)
126
+
127
+
128
+ def main():
129
+ parser = argparse.ArgumentParser()
130
+ parser.add_argument("--camera", type=int, default=0)
131
+ parser.add_argument("--mlp-dir", type=str, default=None)
132
+ parser.add_argument("--max-angle", type=float, default=22.0)
133
+ parser.add_argument("--eye-model", type=str, default=None)
134
+ parser.add_argument("--eye-backend", type=str, default="yolo", choices=["yolo", "geometric"])
135
+ parser.add_argument("--eye-blend", type=float, default=0.5)
136
+ args = parser.parse_args()
137
+
138
+ model_dir = args.mlp_dir or os.path.join(_PROJECT_ROOT, "checkpoints")
139
+
140
+ detector = FaceMeshDetector()
141
+ pipelines = {}
142
+ available_modes = []
143
+
144
+ pipelines[MODE_GEO] = FaceMeshPipeline(
145
+ max_angle=args.max_angle,
146
+ eye_model_path=args.eye_model,
147
+ eye_backend=args.eye_backend,
148
+ eye_blend=args.eye_blend,
149
+ detector=detector,
150
+ )
151
+ available_modes.append(MODE_GEO)
152
+
153
+ mlp_path, _, _ = _latest_model_artifacts(model_dir)
154
+ if mlp_path is not None:
155
+ try:
156
+ pipelines[MODE_MLP] = MLPPipeline(model_dir=model_dir, detector=detector)
157
+ available_modes.append(MODE_MLP)
158
+ except Exception as e:
159
+ print(f"[DEMO] MLP unavailable: {e}")
160
+
161
+ try:
162
+ pipelines[MODE_HYBRID] = HybridFocusPipeline(
163
+ model_dir=model_dir,
164
+ eye_model_path=args.eye_model,
165
+ eye_backend=args.eye_backend,
166
+ eye_blend=args.eye_blend,
167
+ max_angle=args.max_angle,
168
+ detector=detector,
169
+ )
170
+ available_modes.append(MODE_HYBRID)
171
+ except Exception as e:
172
+ print(f"[DEMO] Hybrid unavailable: {e}")
173
+
174
+ gru_arts = _load_gru_artifacts(model_dir)
175
+ if gru_arts[0] is not None:
176
+ try:
177
+ pipelines[MODE_GRU] = GRUPipeline(model_dir=model_dir, detector=detector)
178
+ available_modes.append(MODE_GRU)
179
+ except Exception as e:
180
+ print(f"[DEMO] GRU unavailable: {e}")
181
+
182
+ current_mode = available_modes[0]
183
+ pipeline = pipelines[current_mode]
184
+
185
+ cap = cv2.VideoCapture(args.camera)
186
+ if not cap.isOpened():
187
+ print("[DEMO] ERROR: Cannot open camera")
188
+ return
189
+
190
+ mode_hint = " ".join(f"{k+1}:{_MODE_NAMES[k]}" for k in available_modes)
191
+ print(f"[DEMO] Available modes: {mode_hint}")
192
+ print(f"[DEMO] Active: {_MODE_NAMES[current_mode]}")
193
+ print("[DEMO] q=quit m=mesh 1-4=switch mode")
194
+
195
+ prev_time = time.time()
196
+ fps = 0.0
197
+ mesh_mode = MESH_FULL
198
+
199
+ try:
200
+ while True:
201
+ ret, frame = cap.read()
202
+ if not ret:
203
+ break
204
+
205
+ result = pipeline.process_frame(frame)
206
+ now = time.time()
207
+ fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
208
+ prev_time = now
209
+
210
+ h, w = frame.shape[:2]
211
+ lm = result["landmarks"]
212
+ if lm is not None:
213
+ if mesh_mode == MESH_FULL:
214
+ draw_tessellation(frame, lm, w, h)
215
+ draw_contours(frame, lm, w, h)
216
+ elif mesh_mode == MESH_CONTOURS:
217
+ draw_contours(frame, lm, w, h)
218
+ draw_eyes_and_irises(frame, lm, w, h)
219
+ if hasattr(pipeline, "head_pose"):
220
+ pipeline.head_pose.draw_axes(frame, lm)
221
+ if result.get("left_bbox") and result.get("right_bbox"):
222
+ lx1, ly1, lx2, ly2 = result["left_bbox"]
223
+ rx1, ry1, rx2, ry2 = result["right_bbox"]
224
+ cv2.rectangle(frame, (lx1, ly1), (lx2, ly2), YELLOW, 1)
225
+ cv2.rectangle(frame, (rx1, ry1), (rx2, ry2), YELLOW, 1)
226
+
227
+ # --- HUD ---
228
+ status = "FOCUSED" if result["is_focused"] else "NOT FOCUSED"
229
+ status_color = GREEN if result["is_focused"] else RED
230
+ cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
231
+ cv2.putText(frame, status, (10, 28), FONT, 0.8, status_color, 2, cv2.LINE_AA)
232
+
233
+ mode_label = _MODE_NAMES[current_mode]
234
+ cv2.putText(frame, f"{mode_label} {_MESH_NAMES[mesh_mode]} FPS:{fps:.0f}",
235
+ (w - 340, 28), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
236
+
237
+ detail = ""
238
+ if current_mode == MODE_GEO:
239
+ sf = result.get("s_face", 0)
240
+ se = result.get("s_eye", 0)
241
+ rs = result.get("raw_score", 0)
242
+ mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
243
+ detail = f"S_face:{sf:.2f} S_eye:{se:.2f}{mar_s} score:{rs:.2f}"
244
+ elif current_mode == MODE_MLP:
245
+ mp = result.get("mlp_prob", 0)
246
+ rs = result.get("raw_score", 0)
247
+ mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
248
+ detail = f"mlp_prob:{mp:.2f} score:{rs:.2f}{mar_s}"
249
+ elif current_mode == MODE_GRU:
250
+ gp = result.get("gru_prob", 0)
251
+ rs = result.get("raw_score", 0)
252
+ mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
253
+ detail = f"gru_prob:{gp:.2f} score:{rs:.2f}{mar_s}"
254
+ elif current_mode == MODE_HYBRID:
255
+ mp = result.get("mlp_prob", 0)
256
+ gs = result.get("geo_score", 0)
257
+ fs = result.get("focus_score", 0)
258
+ mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
259
+ detail = f"focus:{fs:.2f} mlp:{mp:.2f} geo:{gs:.2f}{mar_s}"
260
+
261
+ cv2.putText(frame, detail, (10, 48), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
262
+
263
+ if result.get("is_yawning"):
264
+ cv2.putText(frame, "YAWN", (10, 75), FONT, 0.7, ORANGE, 2, cv2.LINE_AA)
265
+
266
+ if result.get("yaw") is not None:
267
+ cv2.putText(
268
+ frame,
269
+ f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
270
+ (w - 280, 48), FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA,
271
+ )
272
+
273
+ cv2.putText(frame, f"q:quit m:mesh {mode_hint}",
274
+ (10, h - 10), FONT, 0.35, (150, 150, 150), 1, cv2.LINE_AA)
275
+
276
+ cv2.imshow("FocusGuard", frame)
277
+
278
+ key = cv2.waitKey(1) & 0xFF
279
+ if key == ord("q"):
280
+ break
281
+ elif key == ord("m"):
282
+ mesh_mode = (mesh_mode + 1) % 3
283
+ print(f"[DEMO] Mesh: {_MESH_NAMES[mesh_mode]}")
284
+ elif key in _MODE_KEYS:
285
+ requested = _MODE_KEYS[key]
286
+ if requested in pipelines:
287
+ current_mode = requested
288
+ pipeline = pipelines[current_mode]
289
+ print(f"[DEMO] Switched to {_MODE_NAMES[current_mode]}")
290
+ else:
291
+ print(f"[DEMO] {_MODE_NAMES[requested]} not available (no checkpoint)")
292
+
293
+ finally:
294
+ cap.release()
295
+ cv2.destroyAllWindows()
296
+ for p in pipelines.values():
297
+ p.close()
298
+ detector.close()
299
+ print("[DEMO] Done")
300
+
301
+
302
+ if __name__ == "__main__":
303
+ main()
ui/pipeline.py ADDED
@@ -0,0 +1,700 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import sys
7
+
8
+ import numpy as np
9
+ import joblib
10
+
11
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
+ if _PROJECT_ROOT not in sys.path:
13
+ sys.path.insert(0, _PROJECT_ROOT)
14
+
15
+ from models.face_mesh import FaceMeshDetector
16
+ from models.head_pose import HeadPoseEstimator
17
+ from models.eye_scorer import EyeBehaviourScorer, compute_mar, MAR_YAWN_THRESHOLD
18
+ from models.eye_crop import extract_eye_crops
19
+ from models.eye_classifier import load_eye_classifier, GeometricOnlyClassifier
20
+ from models.collect_features import FEATURE_NAMES, TemporalTracker, extract_features
21
+
22
+ _FEAT_IDX = {name: i for i, name in enumerate(FEATURE_NAMES)}
23
+
24
+
25
+ def _clip_features(vec):
26
+ """Clip raw features to the same ranges used during training."""
27
+ out = vec.copy()
28
+ _i = _FEAT_IDX
29
+
30
+ out[_i["yaw"]] = np.clip(out[_i["yaw"]], -45, 45)
31
+ out[_i["pitch"]] = np.clip(out[_i["pitch"]], -30, 30)
32
+ out[_i["roll"]] = np.clip(out[_i["roll"]], -30, 30)
33
+
34
+ out[_i["head_deviation"]] = math.sqrt(
35
+ float(out[_i["yaw"]]) ** 2 + float(out[_i["pitch"]]) ** 2
36
+ )
37
+
38
+ for f in ("ear_left", "ear_right", "ear_avg"):
39
+ out[_i[f]] = np.clip(out[_i[f]], 0, 0.85)
40
+
41
+ out[_i["mar"]] = np.clip(out[_i["mar"]], 0, 1.0)
42
+ out[_i["gaze_offset"]] = np.clip(out[_i["gaze_offset"]], 0, 0.50)
43
+ out[_i["perclos"]] = np.clip(out[_i["perclos"]], 0, 0.80)
44
+ out[_i["blink_rate"]] = np.clip(out[_i["blink_rate"]], 0, 30.0)
45
+ out[_i["closure_duration"]] = np.clip(out[_i["closure_duration"]], 0, 10.0)
46
+ out[_i["yawn_duration"]] = np.clip(out[_i["yawn_duration"]], 0, 10.0)
47
+
48
+ return out
49
+
50
+
51
+ class _OutputSmoother:
52
+ """EMA smoothing on focus score with no-face grace period."""
53
+
54
+ def __init__(self, alpha: float = 0.3, grace_frames: int = 15):
55
+ self._alpha = alpha
56
+ self._grace = grace_frames
57
+ self._score = 0.5
58
+ self._no_face = 0
59
+
60
+ def update(self, raw_score: float, face_detected: bool) -> float:
61
+ if face_detected:
62
+ self._no_face = 0
63
+ self._score += self._alpha * (raw_score - self._score)
64
+ else:
65
+ self._no_face += 1
66
+ if self._no_face > self._grace:
67
+ self._score *= 0.85
68
+ return self._score
69
+
70
+
71
+ DEFAULT_HYBRID_CONFIG = {
72
+ "w_mlp": 0.7,
73
+ "w_geo": 0.3,
74
+ "threshold": 0.55,
75
+ "use_yawn_veto": True,
76
+ "geo_face_weight": 0.4,
77
+ "geo_eye_weight": 0.6,
78
+ "mar_yawn_threshold": float(MAR_YAWN_THRESHOLD),
79
+ }
80
+
81
+
82
+ class _RuntimeFeatureEngine:
83
+ """Runtime feature engineering (magnitudes, velocities, variances) with EMA baselines."""
84
+
85
+ _MAG_FEATURES = ["pitch", "yaw", "head_deviation", "gaze_offset", "v_gaze", "h_gaze"]
86
+ _VEL_FEATURES = ["pitch", "yaw", "h_gaze", "v_gaze", "head_deviation", "gaze_offset"]
87
+ _VAR_FEATURES = ["h_gaze", "v_gaze", "pitch"]
88
+ _VAR_WINDOW = 30
89
+ _WARMUP = 15
90
+
91
+ def __init__(self, base_feature_names, norm_features=None):
92
+ self._base_names = list(base_feature_names)
93
+ self._norm_features = list(norm_features) if norm_features else []
94
+
95
+ tracked = set(self._MAG_FEATURES) | set(self._norm_features)
96
+ self._ema_mean = {f: 0.0 for f in tracked}
97
+ self._ema_var = {f: 1.0 for f in tracked}
98
+ self._n = 0
99
+ self._prev = None
100
+ self._var_bufs = {
101
+ f: collections.deque(maxlen=self._VAR_WINDOW) for f in self._VAR_FEATURES
102
+ }
103
+
104
+ self._ext_names = (
105
+ list(self._base_names)
106
+ + [f"{f}_mag" for f in self._MAG_FEATURES]
107
+ + [f"{f}_vel" for f in self._VEL_FEATURES]
108
+ + [f"{f}_var" for f in self._VAR_FEATURES]
109
+ )
110
+
111
+ @property
112
+ def extended_names(self):
113
+ return list(self._ext_names)
114
+
115
+ def transform(self, base_vec):
116
+ self._n += 1
117
+ raw = {name: float(base_vec[i]) for i, name in enumerate(self._base_names)}
118
+
119
+ alpha = 2.0 / (min(self._n, 120) + 1)
120
+ for feat in self._ema_mean:
121
+ if feat not in raw:
122
+ continue
123
+ v = raw[feat]
124
+ if self._n == 1:
125
+ self._ema_mean[feat] = v
126
+ self._ema_var[feat] = 0.0
127
+ else:
128
+ self._ema_mean[feat] += alpha * (v - self._ema_mean[feat])
129
+ self._ema_var[feat] += alpha * (
130
+ (v - self._ema_mean[feat]) ** 2 - self._ema_var[feat]
131
+ )
132
+
133
+ out = base_vec.copy().astype(np.float32)
134
+ if self._n > self._WARMUP:
135
+ for feat in self._norm_features:
136
+ if feat in raw:
137
+ idx = self._base_names.index(feat)
138
+ std = max(math.sqrt(self._ema_var[feat]), 1e-6)
139
+ out[idx] = (raw[feat] - self._ema_mean[feat]) / std
140
+
141
+ mag = np.zeros(len(self._MAG_FEATURES), dtype=np.float32)
142
+ for i, feat in enumerate(self._MAG_FEATURES):
143
+ if feat in raw:
144
+ mag[i] = abs(raw[feat] - self._ema_mean.get(feat, raw[feat]))
145
+
146
+ vel = np.zeros(len(self._VEL_FEATURES), dtype=np.float32)
147
+ if self._prev is not None:
148
+ for i, feat in enumerate(self._VEL_FEATURES):
149
+ if feat in raw and feat in self._prev:
150
+ vel[i] = abs(raw[feat] - self._prev[feat])
151
+ self._prev = dict(raw)
152
+
153
+ for feat in self._VAR_FEATURES:
154
+ if feat in raw:
155
+ self._var_bufs[feat].append(raw[feat])
156
+ var = np.zeros(len(self._VAR_FEATURES), dtype=np.float32)
157
+ for i, feat in enumerate(self._VAR_FEATURES):
158
+ buf = self._var_bufs[feat]
159
+ if len(buf) >= 2:
160
+ arr = np.array(buf)
161
+ var[i] = float(arr.var())
162
+
163
+ return np.concatenate([out, mag, vel, var])
164
+
165
+
166
+ class FaceMeshPipeline:
167
+ def __init__(
168
+ self,
169
+ max_angle: float = 22.0,
170
+ alpha: float = 0.4,
171
+ beta: float = 0.6,
172
+ threshold: float = 0.55,
173
+ eye_model_path: str | None = None,
174
+ eye_backend: str = "yolo",
175
+ eye_blend: float = 0.5,
176
+ detector=None,
177
+ ):
178
+ self.detector = detector or FaceMeshDetector()
179
+ self._owns_detector = detector is None
180
+ self.head_pose = HeadPoseEstimator(max_angle=max_angle)
181
+ self.eye_scorer = EyeBehaviourScorer()
182
+ self.alpha = alpha
183
+ self.beta = beta
184
+ self.threshold = threshold
185
+ self.eye_blend = eye_blend
186
+
187
+ self.eye_classifier = load_eye_classifier(
188
+ path=eye_model_path if eye_model_path and os.path.exists(eye_model_path) else None,
189
+ backend=eye_backend,
190
+ device="cpu",
191
+ )
192
+ self._has_eye_model = not isinstance(self.eye_classifier, GeometricOnlyClassifier)
193
+ if self._has_eye_model:
194
+ print(f"[PIPELINE] Eye model: {self.eye_classifier.name}")
195
+ self._smoother = _OutputSmoother()
196
+
197
+ def process_frame(self, bgr_frame: np.ndarray) -> dict:
198
+ landmarks = self.detector.process(bgr_frame)
199
+ h, w = bgr_frame.shape[:2]
200
+
201
+ out = {
202
+ "landmarks": landmarks,
203
+ "s_face": 0.0,
204
+ "s_eye": 0.0,
205
+ "raw_score": 0.0,
206
+ "is_focused": False,
207
+ "yaw": None,
208
+ "pitch": None,
209
+ "roll": None,
210
+ "mar": None,
211
+ "is_yawning": False,
212
+ "left_bbox": None,
213
+ "right_bbox": None,
214
+ }
215
+
216
+ if landmarks is None:
217
+ smoothed = self._smoother.update(0.0, False)
218
+ out["raw_score"] = smoothed
219
+ out["is_focused"] = smoothed >= self.threshold
220
+ return out
221
+
222
+ angles = self.head_pose.estimate(landmarks, w, h)
223
+ if angles is not None:
224
+ out["yaw"], out["pitch"], out["roll"] = angles
225
+ out["s_face"] = self.head_pose.score(landmarks, w, h)
226
+
227
+ s_eye_geo = self.eye_scorer.score(landmarks)
228
+ if self._has_eye_model:
229
+ left_crop, right_crop, left_bbox, right_bbox = extract_eye_crops(bgr_frame, landmarks)
230
+ out["left_bbox"] = left_bbox
231
+ out["right_bbox"] = right_bbox
232
+ s_eye_model = self.eye_classifier.predict_score([left_crop, right_crop])
233
+ out["s_eye"] = (1.0 - self.eye_blend) * s_eye_geo + self.eye_blend * s_eye_model
234
+ else:
235
+ out["s_eye"] = s_eye_geo
236
+
237
+ out["mar"] = compute_mar(landmarks)
238
+ out["is_yawning"] = out["mar"] > MAR_YAWN_THRESHOLD
239
+
240
+ raw = self.alpha * out["s_face"] + self.beta * out["s_eye"]
241
+ if out["is_yawning"]:
242
+ raw = 0.0
243
+ out["raw_score"] = self._smoother.update(raw, True)
244
+ out["is_focused"] = out["raw_score"] >= self.threshold
245
+
246
+ return out
247
+
248
+ @property
249
+ def has_eye_model(self) -> bool:
250
+ return self._has_eye_model
251
+
252
+ def close(self):
253
+ if self._owns_detector:
254
+ self.detector.close()
255
+
256
+ def __enter__(self):
257
+ return self
258
+
259
+ def __exit__(self, *args):
260
+ self.close()
261
+
262
+
263
+ def _latest_model_artifacts(model_dir):
264
+ model_files = sorted(glob.glob(os.path.join(model_dir, "model_*.joblib")))
265
+ if not model_files:
266
+ model_files = sorted(glob.glob(os.path.join(model_dir, "mlp_*.joblib")))
267
+ if not model_files:
268
+ return None, None, None
269
+ basename = os.path.basename(model_files[-1])
270
+ for prefix in ("model_", "mlp_"):
271
+ if basename.startswith(prefix):
272
+ tag = basename[len(prefix) :].replace(".joblib", "")
273
+ break
274
+ scaler_path = os.path.join(model_dir, f"scaler_{tag}.joblib")
275
+ meta_path = os.path.join(model_dir, f"meta_{tag}.npz")
276
+ if not os.path.isfile(scaler_path) or not os.path.isfile(meta_path):
277
+ return None, None, None
278
+ return model_files[-1], scaler_path, meta_path
279
+
280
+
281
+ def _load_hybrid_config(model_dir: str, config_path: str | None = None):
282
+ cfg = dict(DEFAULT_HYBRID_CONFIG)
283
+ resolved = config_path or os.path.join(model_dir, "hybrid_focus_config.json")
284
+ if not os.path.isfile(resolved):
285
+ print(f"[HYBRID] No config found at {resolved}; using defaults")
286
+ return cfg, None
287
+
288
+ with open(resolved, "r", encoding="utf-8") as f:
289
+ file_cfg = json.load(f)
290
+
291
+ for key in DEFAULT_HYBRID_CONFIG:
292
+ if key in file_cfg:
293
+ cfg[key] = file_cfg[key]
294
+
295
+ cfg["w_mlp"] = float(cfg["w_mlp"])
296
+ cfg["w_geo"] = float(cfg["w_geo"])
297
+ weight_sum = cfg["w_mlp"] + cfg["w_geo"]
298
+ if weight_sum <= 0:
299
+ raise ValueError("[HYBRID] Invalid config: w_mlp + w_geo must be > 0")
300
+ cfg["w_mlp"] /= weight_sum
301
+ cfg["w_geo"] /= weight_sum
302
+ cfg["threshold"] = float(cfg["threshold"])
303
+ cfg["use_yawn_veto"] = bool(cfg["use_yawn_veto"])
304
+ cfg["geo_face_weight"] = float(cfg["geo_face_weight"])
305
+ cfg["geo_eye_weight"] = float(cfg["geo_eye_weight"])
306
+ cfg["mar_yawn_threshold"] = float(cfg["mar_yawn_threshold"])
307
+
308
+ print(f"[HYBRID] Loaded config: {resolved}")
309
+ return cfg, resolved
310
+
311
+
312
+ class MLPPipeline:
313
+ def __init__(self, model_dir=None, detector=None):
314
+ if model_dir is None:
315
+ model_dir = os.path.join(_PROJECT_ROOT, "checkpoints")
316
+ mlp_path, scaler_path, meta_path = _latest_model_artifacts(model_dir)
317
+ if mlp_path is None:
318
+ raise FileNotFoundError(f"No MLP artifacts in {model_dir}")
319
+ self._mlp = joblib.load(mlp_path)
320
+ self._scaler = joblib.load(scaler_path)
321
+ meta = np.load(meta_path, allow_pickle=True)
322
+ self._feature_names = list(meta["feature_names"])
323
+
324
+ norm_feats = list(meta["norm_features"]) if "norm_features" in meta else []
325
+ self._engine = _RuntimeFeatureEngine(FEATURE_NAMES, norm_features=norm_feats)
326
+ ext_names = self._engine.extended_names
327
+ self._indices = [ext_names.index(n) for n in self._feature_names]
328
+
329
+ self._detector = detector or FaceMeshDetector()
330
+ self._owns_detector = detector is None
331
+ self._head_pose = HeadPoseEstimator()
332
+ self.head_pose = self._head_pose
333
+ self._eye_scorer = EyeBehaviourScorer()
334
+ self._temporal = TemporalTracker()
335
+ self._smoother = _OutputSmoother()
336
+ self._threshold = 0.5
337
+ print(f"[MLP] Loaded {mlp_path} | {len(self._feature_names)} features")
338
+
339
+ def process_frame(self, bgr_frame):
340
+ landmarks = self._detector.process(bgr_frame)
341
+ h, w = bgr_frame.shape[:2]
342
+ out = {
343
+ "landmarks": landmarks,
344
+ "is_focused": False,
345
+ "s_face": 0.0,
346
+ "s_eye": 0.0,
347
+ "raw_score": 0.0,
348
+ "mlp_prob": 0.0,
349
+ "mar": None,
350
+ "yaw": None,
351
+ "pitch": None,
352
+ "roll": None,
353
+ }
354
+ if landmarks is None:
355
+ smoothed = self._smoother.update(0.0, False)
356
+ out["raw_score"] = smoothed
357
+ out["is_focused"] = smoothed >= self._threshold
358
+ return out
359
+ vec = extract_features(landmarks, w, h, self._head_pose, self._eye_scorer, self._temporal)
360
+ vec = _clip_features(vec)
361
+
362
+ out["yaw"] = float(vec[_FEAT_IDX["yaw"]])
363
+ out["pitch"] = float(vec[_FEAT_IDX["pitch"]])
364
+ out["roll"] = float(vec[_FEAT_IDX["roll"]])
365
+ out["s_face"] = float(vec[_FEAT_IDX["s_face"]])
366
+ out["s_eye"] = float(vec[_FEAT_IDX["s_eye"]])
367
+ out["mar"] = float(vec[_FEAT_IDX["mar"]])
368
+
369
+ ext_vec = self._engine.transform(vec)
370
+ X = ext_vec[self._indices].reshape(1, -1).astype(np.float64)
371
+ X_sc = self._scaler.transform(X)
372
+ if hasattr(self._mlp, "predict_proba"):
373
+ mlp_prob = float(self._mlp.predict_proba(X_sc)[0, 1])
374
+ else:
375
+ mlp_prob = float(self._mlp.predict(X_sc)[0] == 1)
376
+ out["mlp_prob"] = float(np.clip(mlp_prob, 0.0, 1.0))
377
+ out["raw_score"] = self._smoother.update(out["mlp_prob"], True)
378
+ out["is_focused"] = out["raw_score"] >= self._threshold
379
+ return out
380
+
381
+ def close(self):
382
+ if self._owns_detector:
383
+ self._detector.close()
384
+
385
+ def __enter__(self):
386
+ return self
387
+
388
+ def __exit__(self, *args):
389
+ self.close()
390
+
391
+
392
+ class HybridFocusPipeline:
393
+ def __init__(
394
+ self,
395
+ model_dir=None,
396
+ config_path: str | None = None,
397
+ eye_model_path: str | None = None,
398
+ eye_backend: str = "yolo",
399
+ eye_blend: float = 0.5,
400
+ max_angle: float = 22.0,
401
+ detector=None,
402
+ ):
403
+ if model_dir is None:
404
+ model_dir = os.path.join(_PROJECT_ROOT, "checkpoints")
405
+ mlp_path, scaler_path, meta_path = _latest_model_artifacts(model_dir)
406
+ if mlp_path is None:
407
+ raise FileNotFoundError(f"No MLP artifacts in {model_dir}")
408
+
409
+ self._mlp = joblib.load(mlp_path)
410
+ self._scaler = joblib.load(scaler_path)
411
+ meta = np.load(meta_path, allow_pickle=True)
412
+ self._feature_names = list(meta["feature_names"])
413
+
414
+ norm_feats = list(meta["norm_features"]) if "norm_features" in meta else []
415
+ self._engine = _RuntimeFeatureEngine(FEATURE_NAMES, norm_features=norm_feats)
416
+ ext_names = self._engine.extended_names
417
+ self._indices = [ext_names.index(n) for n in self._feature_names]
418
+
419
+ self._cfg, self._cfg_path = _load_hybrid_config(model_dir=model_dir, config_path=config_path)
420
+
421
+ self._detector = detector or FaceMeshDetector()
422
+ self._owns_detector = detector is None
423
+ self._head_pose = HeadPoseEstimator(max_angle=max_angle)
424
+ self._eye_scorer = EyeBehaviourScorer()
425
+ self._temporal = TemporalTracker()
426
+ self._eye_blend = eye_blend
427
+ self.eye_classifier = load_eye_classifier(
428
+ path=eye_model_path if eye_model_path and os.path.exists(eye_model_path) else None,
429
+ backend=eye_backend,
430
+ device="cpu",
431
+ )
432
+ self._has_eye_model = not isinstance(self.eye_classifier, GeometricOnlyClassifier)
433
+ if self._has_eye_model:
434
+ print(f"[HYBRID] Eye model: {self.eye_classifier.name}")
435
+
436
+ self.head_pose = self._head_pose
437
+ self._smoother = _OutputSmoother()
438
+
439
+ print(
440
+ f"[HYBRID] Loaded {mlp_path} | {len(self._feature_names)} features | "
441
+ f"w_mlp={self._cfg['w_mlp']:.2f}, w_geo={self._cfg['w_geo']:.2f}, "
442
+ f"threshold={self._cfg['threshold']:.2f}"
443
+ )
444
+
445
+ @property
446
+ def has_eye_model(self) -> bool:
447
+ return self._has_eye_model
448
+
449
+ @property
450
+ def config(self) -> dict:
451
+ return dict(self._cfg)
452
+
453
+ def process_frame(self, bgr_frame: np.ndarray) -> dict:
454
+ landmarks = self._detector.process(bgr_frame)
455
+ h, w = bgr_frame.shape[:2]
456
+ out = {
457
+ "landmarks": landmarks,
458
+ "is_focused": False,
459
+ "focus_score": 0.0,
460
+ "mlp_prob": 0.0,
461
+ "geo_score": 0.0,
462
+ "raw_score": 0.0,
463
+ "s_face": 0.0,
464
+ "s_eye": 0.0,
465
+ "mar": None,
466
+ "is_yawning": False,
467
+ "yaw": None,
468
+ "pitch": None,
469
+ "roll": None,
470
+ "left_bbox": None,
471
+ "right_bbox": None,
472
+ }
473
+ if landmarks is None:
474
+ smoothed = self._smoother.update(0.0, False)
475
+ out["focus_score"] = smoothed
476
+ out["raw_score"] = smoothed
477
+ out["is_focused"] = smoothed >= self._cfg["threshold"]
478
+ return out
479
+
480
+ angles = self._head_pose.estimate(landmarks, w, h)
481
+ if angles is not None:
482
+ out["yaw"], out["pitch"], out["roll"] = angles
483
+
484
+ out["s_face"] = self._head_pose.score(landmarks, w, h)
485
+ s_eye_geo = self._eye_scorer.score(landmarks)
486
+ if self._has_eye_model:
487
+ left_crop, right_crop, left_bbox, right_bbox = extract_eye_crops(bgr_frame, landmarks)
488
+ out["left_bbox"] = left_bbox
489
+ out["right_bbox"] = right_bbox
490
+ s_eye_model = self.eye_classifier.predict_score([left_crop, right_crop])
491
+ out["s_eye"] = (1.0 - self._eye_blend) * s_eye_geo + self._eye_blend * s_eye_model
492
+ else:
493
+ out["s_eye"] = s_eye_geo
494
+
495
+ geo_score = (
496
+ self._cfg["geo_face_weight"] * out["s_face"] +
497
+ self._cfg["geo_eye_weight"] * out["s_eye"]
498
+ )
499
+ geo_score = float(np.clip(geo_score, 0.0, 1.0))
500
+
501
+ out["mar"] = compute_mar(landmarks)
502
+ out["is_yawning"] = out["mar"] > self._cfg["mar_yawn_threshold"]
503
+ if self._cfg["use_yawn_veto"] and out["is_yawning"]:
504
+ geo_score = 0.0
505
+ out["geo_score"] = geo_score
506
+
507
+ pre = {
508
+ "angles": angles,
509
+ "s_face": out["s_face"],
510
+ "s_eye": s_eye_geo,
511
+ "mar": out["mar"],
512
+ }
513
+ vec = extract_features(landmarks, w, h, self._head_pose, self._eye_scorer, self._temporal, _pre=pre)
514
+ vec = _clip_features(vec)
515
+ ext_vec = self._engine.transform(vec)
516
+ X = ext_vec[self._indices].reshape(1, -1).astype(np.float64)
517
+ X_sc = self._scaler.transform(X)
518
+ if hasattr(self._mlp, "predict_proba"):
519
+ mlp_prob = float(self._mlp.predict_proba(X_sc)[0, 1])
520
+ else:
521
+ mlp_prob = float(self._mlp.predict(X_sc)[0] == 1)
522
+ out["mlp_prob"] = float(np.clip(mlp_prob, 0.0, 1.0))
523
+
524
+ focus_score = self._cfg["w_mlp"] * out["mlp_prob"] + self._cfg["w_geo"] * out["geo_score"]
525
+ out["focus_score"] = self._smoother.update(float(np.clip(focus_score, 0.0, 1.0)), True)
526
+ out["raw_score"] = out["focus_score"]
527
+ out["is_focused"] = out["focus_score"] >= self._cfg["threshold"]
528
+ return out
529
+
530
+ def close(self):
531
+ if self._owns_detector:
532
+ self._detector.close()
533
+
534
+ def __enter__(self):
535
+ return self
536
+
537
+ def __exit__(self, *args):
538
+ self.close()
539
+
540
+
541
+ # ---------------------------------------------------------------------------
542
+ # GRU Pipeline
543
+ # ---------------------------------------------------------------------------
544
+
545
+ def _load_gru_artifacts(model_dir=None):
546
+ if model_dir is None:
547
+ model_dir = os.path.join(_PROJECT_ROOT, "checkpoints")
548
+ pt_path = os.path.join(model_dir, "gru_best.pt")
549
+ scaler_path = os.path.join(model_dir, "gru_scaler_best.npz")
550
+ meta_path = os.path.join(model_dir, "gru_meta_best.npz")
551
+ if not all(os.path.isfile(p) for p in [pt_path, scaler_path, meta_path]):
552
+ return None, None, None
553
+ return pt_path, scaler_path, meta_path
554
+
555
+
556
+ class _AttentionGRU:
557
+
558
+ def __init__(self, pt_path, input_size, hidden_size=64, num_layers=2, dropout=0.3):
559
+ import torch
560
+ import torch.nn as nn
561
+
562
+ class _GRUNet(nn.Module):
563
+ def __init__(self, in_sz, h_sz, n_layers, drop):
564
+ super().__init__()
565
+ self.gru = nn.GRU(
566
+ input_size=in_sz, hidden_size=h_sz,
567
+ num_layers=n_layers, batch_first=True,
568
+ dropout=drop if n_layers > 1 else 0.0,
569
+ )
570
+ self.classifier = nn.Sequential(
571
+ nn.Dropout(drop),
572
+ nn.Linear(h_sz, 32),
573
+ nn.ReLU(),
574
+ nn.Dropout(drop * 0.5),
575
+ nn.Linear(32, 1),
576
+ )
577
+
578
+ def forward(self, x):
579
+ gru_out, _ = self.gru(x)
580
+ return self.classifier(gru_out[:, -1, :])
581
+
582
+ self._device = torch.device("cpu")
583
+ self._model = _GRUNet(input_size, hidden_size, num_layers, dropout)
584
+ checkpoint = torch.load(pt_path, map_location=self._device, weights_only=False)
585
+ if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
586
+ self._model.load_state_dict(checkpoint["model_state_dict"])
587
+ else:
588
+ self._model.load_state_dict(checkpoint)
589
+ self._model.eval()
590
+
591
+ def predict_proba(self, x_np):
592
+ """x_np: (1, window, features) numpy array -> float probability of focused."""
593
+ import torch
594
+ with torch.no_grad():
595
+ t = torch.tensor(x_np, dtype=torch.float32, device=self._device)
596
+ logit = self._model(t)
597
+ prob = torch.sigmoid(logit).item()
598
+ return prob
599
+
600
+
601
+ class GRUPipeline:
602
+
603
+ def __init__(self, model_dir=None, detector=None):
604
+ pt_path, scaler_path, meta_path = _load_gru_artifacts(model_dir)
605
+ if pt_path is None:
606
+ d = model_dir or os.path.join(_PROJECT_ROOT, "checkpoints")
607
+ raise FileNotFoundError(f"No GRU artifacts in {d}")
608
+
609
+ meta = np.load(meta_path, allow_pickle=True)
610
+ self._feature_names = list(meta["feature_names"])
611
+ self._window_size = int(meta["window_size"])
612
+ hidden_size = int(meta["hidden_size"])
613
+ num_layers = int(meta["num_layers"])
614
+ dropout = float(meta["dropout"])
615
+ self._threshold = float(meta["default_threshold"])
616
+
617
+ sc = np.load(scaler_path)
618
+ self._sc_mean = sc["mean"]
619
+ self._sc_scale = sc["scale"]
620
+
621
+ self._gru = _AttentionGRU(
622
+ pt_path, input_size=len(self._feature_names),
623
+ hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
624
+ )
625
+
626
+ self._feat_indices = [FEATURE_NAMES.index(n) for n in self._feature_names]
627
+
628
+ self._detector = detector or FaceMeshDetector()
629
+ self._owns_detector = detector is None
630
+ self._head_pose = HeadPoseEstimator()
631
+ self.head_pose = self._head_pose
632
+ self._eye_scorer = EyeBehaviourScorer()
633
+ self._temporal = TemporalTracker()
634
+ self._smoother = _OutputSmoother(alpha=0.6, grace_frames=10)
635
+
636
+ self._buffer = collections.deque(maxlen=self._window_size)
637
+
638
+ print(
639
+ f"[GRU] Loaded {pt_path} | {len(self._feature_names)} features | "
640
+ f"window={self._window_size} | threshold={self._threshold:.3f}"
641
+ )
642
+
643
+ def process_frame(self, bgr_frame):
644
+ landmarks = self._detector.process(bgr_frame)
645
+ h, w = bgr_frame.shape[:2]
646
+ out = {
647
+ "landmarks": landmarks,
648
+ "is_focused": False,
649
+ "raw_score": 0.0,
650
+ "gru_prob": 0.0,
651
+ "s_face": 0.0,
652
+ "s_eye": 0.0,
653
+ "mar": None,
654
+ "yaw": None,
655
+ "pitch": None,
656
+ "roll": None,
657
+ }
658
+ if landmarks is None:
659
+ smoothed = self._smoother.update(0.0, False)
660
+ out["raw_score"] = smoothed
661
+ out["is_focused"] = smoothed >= self._threshold
662
+ return out
663
+
664
+ vec = extract_features(landmarks, w, h, self._head_pose, self._eye_scorer, self._temporal)
665
+ vec = _clip_features(vec)
666
+
667
+ out["yaw"] = float(vec[_FEAT_IDX["yaw"]])
668
+ out["pitch"] = float(vec[_FEAT_IDX["pitch"]])
669
+ out["roll"] = float(vec[_FEAT_IDX["roll"]])
670
+ out["s_face"] = float(vec[_FEAT_IDX["s_face"]])
671
+ out["s_eye"] = float(vec[_FEAT_IDX["s_eye"]])
672
+ out["mar"] = float(vec[_FEAT_IDX["mar"]])
673
+
674
+ selected = vec[self._feat_indices].astype(np.float64)
675
+ scaled = (selected - self._sc_mean) / np.maximum(self._sc_scale, 1e-8)
676
+ scaled_f32 = scaled.astype(np.float32)
677
+
678
+ # Pad buffer on first frame so GRU can predict immediately
679
+ if len(self._buffer) == 0:
680
+ for _ in range(self._window_size):
681
+ self._buffer.append(scaled_f32)
682
+ else:
683
+ self._buffer.append(scaled_f32)
684
+
685
+ window = np.array(self._buffer)[np.newaxis, :, :] # (1, W, F)
686
+ gru_prob = self._gru.predict_proba(window)
687
+ out["gru_prob"] = float(np.clip(gru_prob, 0.0, 1.0))
688
+ out["raw_score"] = self._smoother.update(out["gru_prob"], True)
689
+ out["is_focused"] = out["raw_score"] >= self._threshold
690
+ return out
691
+
692
+ def close(self):
693
+ if self._owns_detector:
694
+ self._detector.close()
695
+
696
+ def __enter__(self):
697
+ return self
698
+
699
+ def __exit__(self, *args):
700
+ self.close()
vite.config.js ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { defineConfig } from 'vite'
2
+ import react from '@vitejs/plugin-react'
3
+
4
+ // https://vitejs.dev/config/
5
+ export default defineConfig({
6
+ plugins: [react()],
7
+ server: {
8
+ proxy: {
9
+ // 告诉 Vite:凡是以 /api 开头的请求,都转发给 Python 后端
10
+ '/api': {
11
+ target: 'http://localhost:8000',
12
+ changeOrigin: true,
13
+ secure: false,
14
+ },
15
+ // 告诉 Vite:凡是以 /ws 开头的请求,都转发给 Python 后端
16
+ '/ws': {
17
+ target: 'ws://localhost:8000',
18
+ ws: true,
19
+ }
20
+ }
21
+ }
22
+ })