Spaces: k22056537 committed
Commit eb4abb8 · Parent(s): 9d50060

feat: sync integration updates across app and ML pipeline

Update backend/frontend integration, dataset config flow, and training/evaluation scripts, and refresh generated evaluation reports to match the latest runs.
- README.md +62 -0
- api/__init__.py +1 -0
- api/db.py +201 -0
- api/drawing.py +124 -0
- config/__init__.py +57 -0
- config/default.yaml +79 -0
- data_preparation/prepare_dataset.py +47 -9
- evaluation/GROUPED_SPLIT_BENCHMARK.md +13 -0
- evaluation/README.md +4 -1
- evaluation/THRESHOLD_JUSTIFICATION.md +19 -125
- evaluation/feature_importance.py +106 -57
- evaluation/feature_selection_justification.md +15 -16
- evaluation/grouped_split_benchmark.py +107 -0
- evaluation/justify_thresholds.py +17 -14
- evaluation/plots/roc_xgb.png +0 -0
- main.py +57 -341
- models/L2CS-Net/l2cs/datasets.py +0 -10
- models/mlp/eval_accuracy.py +0 -2
- models/mlp/sweep.py +3 -3
- models/mlp/train.py +169 -51
- models/xgboost/add_accuracy.py +1 -3
- models/xgboost/config.py +52 -0
- models/xgboost/eval_accuracy.py +0 -2
- models/xgboost/sweep_local.py +2 -3
- models/xgboost/train.py +126 -57
- requirements.txt +1 -0
- src/App.jsx +1 -1
- src/components/Achievement.jsx +0 -16
- src/components/Customise.jsx +1 -1
- src/components/FocusPageLocal.jsx +5 -9
- src/utils/VideoManagerLocal.js +5 -1
- tests/test_api_settings.py +4 -18
- tests/test_data_preparation.py +18 -5
README.md
CHANGED

@@ -2,6 +2,8 @@
 
 Webcam-based focus detection: MediaPipe face mesh -> 17 features (EAR, gaze, head pose, PERCLOS, etc.) -> MLP or XGBoost for focused/unfocused. React + FastAPI app with WebSocket video.
 
+**Repository:** Add your repo link here (e.g. `https://github.com/your-org/FocusGuard`).
+
 ## Project layout
 
 ```
@@ -27,6 +29,10 @@ Webcam-based focus detection: MediaPipe face mesh -> 17 features (EAR, gaze, hea
 └── package.json
 ```
 
+## Config
+
+Hyperparameters and app settings live in `config/default.yaml` (learning rates, batch size, thresholds, L2CS weights, etc.). Override with env `FOCUSGUARD_CONFIG` pointing to another YAML.
+
 ## Setup
 
 ```bash
@@ -74,10 +80,30 @@ python -m models.mlp.train
 python -m models.xgboost.train
 ```
 
+### ClearML experiment tracking
+
+All training and evaluation config (from `config/default.yaml`) is exposed as ClearML task parameters. Enable logging with `USE_CLEARML=1`; optionally run on a **remote GPU agent** instead of locally:
+
+```bash
+USE_CLEARML=1 CLEARML_QUEUE=gpu python -m models.mlp.train
+USE_CLEARML=1 CLEARML_QUEUE=gpu python -m models.xgboost.train
+USE_CLEARML=1 CLEARML_QUEUE=gpu python -m evaluation.justify_thresholds --clearml
+```
+
+The script enqueues the task and exits; a `clearml-agent` listening on the named queue (e.g. `gpu`) runs the same command with the same parameters. Start an agent with:
+
+```bash
+clearml-agent daemon --queue gpu
+```
+
+Logged to ClearML: **parameters** (full flattened config), **scalars** (loss, accuracy, F1, ROC-AUC, per-class precision/recall/F1, dataset sizes and class counts), **artifacts** (best checkpoint, training log JSON), and **plots** (confusion matrix, ROC curves in evaluation).
+
 ## Data
 
 9 participants, 144,793 samples, 10 features, binary labels. Collect with `python -m models.collect_features --name <name>`. Data lives in `data/collected_<name>/`.
 
+**Train/val/test split:** All pooled training and evaluation use the same split for reproducibility. The test set is held out before any preprocessing; `StandardScaler` is fit on the training set only, then applied to val and test. Split ratios and random seed come from `config/default.yaml` (`data.split_ratios`, `mlp.seed`) via `data_preparation.prepare_dataset.get_default_split_config()`. MLP train, XGBoost train, eval_accuracy scripts, and benchmarks all use this single source so reported test accuracy is on the same held-out set.
+
 ## Models
 
 | Model | What it uses | Best for |
@@ -95,6 +121,42 @@ python -m models.xgboost.train
 | XGBoost (600 trees, depth 8) | 95.87% | 0.959 | 0.991 |
 | MLP (64->32) | 92.92% | 0.929 | 0.971 |
 
+## Model numbers (LOPO, 9 participants)
+
+| Model   | LOPO AUC | Best threshold (Youden's J) | F1 @ best threshold | F1 @ 0.50 |
+|---------|----------|-----------------------------|---------------------|-----------|
+| MLP     | 0.8624   | 0.228                       | 0.8578              | 0.8149    |
+| XGBoost | 0.8695   | 0.280                       | 0.8549              | 0.8324    |
+
+From the latest `python -m evaluation.justify_thresholds` run:
+- Best geometric face weight (`alpha`) = `0.7` (mean LOPO F1 = `0.8195`)
+- Best hybrid MLP weight (`w_mlp`) = `0.3` (mean LOPO F1 = `0.8409`)
+
+## Grouped vs pooled benchmark
+
+Latest quick benchmark (`python -m evaluation.grouped_split_benchmark --quick`) shows the expected gap between pooled random split and person-held-out LOPO:
+
+| Protocol               | Accuracy | F1 (weighted) | ROC-AUC |
+|------------------------|---------:|--------------:|--------:|
+| Pooled random split    |   0.9510 |        0.9507 |  0.9869 |
+| Grouped LOPO (9 folds) |   0.8303 |        0.8304 |  0.8801 |
+
+This is why LOPO is the primary generalisation metric for reporting.
+
+## Feature ablation snapshot
+
+Latest quick feature-selection run (`python -m evaluation.feature_importance --quick --skip-lofo`):
+
+| Subset    | Mean LOPO F1 |
+|-----------|--------------|
+| all_10    | 0.8286       |
+| eye_state | 0.8071       |
+| head_pose | 0.7480       |
+| gaze      | 0.7260       |
+
+Top-5 XGBoost gain features: `s_face`, `ear_right`, `head_deviation`, `ear_avg`, `perclos`.
+For full leave-one-feature-out ablation, run `python -m evaluation.feature_importance` (slower).
+
 ## L2CS Gaze Tracking
 
 L2CS-Net predicts where your eyes are looking, not just where your head is pointed. This catches the scenario where your head faces the screen but your eyes wander.
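The "Best threshold (Youden's J)" column above is the cutoff that maximises TPR − FPR on the LOPO ROC curve. A minimal numpy sketch of that selection (an illustration, not the repo's `justify_thresholds` implementation):

```python
import numpy as np

def best_threshold_youden(y_true: np.ndarray, scores: np.ndarray) -> float:
    """Pick the score cutoff maximising Youden's J = TPR - FPR."""
    best_t, best_j = 0.5, -1.0
    pos = (y_true == 1).sum()
    neg = (y_true == 0).sum()
    for t in np.unique(scores):
        pred = scores >= t
        tpr = (pred & (y_true == 1)).sum() / pos  # true positive rate
        fpr = (pred & (y_true == 0)).sum() / neg  # false positive rate
        if tpr - fpr > best_j:
            best_j, best_t = tpr - fpr, float(t)
    return best_t

# Toy data: positives mostly score high, one hard negative at 0.6.
y = np.array([0, 0, 0, 1, 1, 1])
s = np.array([0.1, 0.2, 0.6, 0.5, 0.8, 0.9])
print(best_threshold_youden(y, s))  # → 0.5
```

On imbalanced or shifted score distributions this typically lands well away from 0.50, which is why the table reports F1 at both cutoffs.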
api/__init__.py
ADDED

@@ -0,0 +1 @@
+# API package: db, drawing, routes, websocket.
api/db.py
ADDED

@@ -0,0 +1,201 @@
```python
"""SQLite DB for focus sessions and user settings."""

from __future__ import annotations

import asyncio
import json
from datetime import datetime

import aiosqlite


def get_db_path() -> str:
    """Database file path from config or default."""
    try:
        from config import get
        return get("app.db_path") or "focus_guard.db"
    except Exception:
        return "focus_guard.db"


async def init_database(db_path: str | None = None) -> None:
    """Create focus_sessions, focus_events, user_settings tables if missing."""
    path = db_path or get_db_path()
    async with aiosqlite.connect(path) as db:
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                start_time TIMESTAMP NOT NULL,
                end_time TIMESTAMP,
                duration_seconds INTEGER DEFAULT 0,
                focus_score REAL DEFAULT 0.0,
                total_frames INTEGER DEFAULT 0,
                focused_frames INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_events (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id INTEGER NOT NULL,
                timestamp TIMESTAMP NOT NULL,
                is_focused BOOLEAN NOT NULL,
                confidence REAL NOT NULL,
                detection_data TEXT,
                FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
            )
        """)
        await db.execute("""
            CREATE TABLE IF NOT EXISTS user_settings (
                id INTEGER PRIMARY KEY CHECK (id = 1),
                model_name TEXT DEFAULT 'mlp'
            )
        """)
        await db.execute("""
            INSERT OR IGNORE INTO user_settings (id, model_name)
            VALUES (1, 'mlp')
        """)
        await db.commit()


async def create_session(db_path: str | None = None) -> int:
    """Insert a new focus session. Returns session id."""
    path = db_path or get_db_path()
    async with aiosqlite.connect(path) as db:
        cursor = await db.execute(
            "INSERT INTO focus_sessions (start_time) VALUES (?)",
            (datetime.now().isoformat(),),
        )
        await db.commit()
        return cursor.lastrowid


async def end_session(session_id: int, db_path: str | None = None) -> dict | None:
    """Close session and return summary (duration, focus_score, etc.)."""
    path = db_path or get_db_path()
    async with aiosqlite.connect(path) as db:
        cursor = await db.execute(
            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
            (session_id,),
        )
        row = await cursor.fetchone()
    if not row:
        return None
    start_time_str, total_frames, focused_frames = row
    start_time = datetime.fromisoformat(start_time_str)
    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds()
    focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
    async with aiosqlite.connect(path) as db:
        await db.execute("""
            UPDATE focus_sessions
            SET end_time = ?, duration_seconds = ?, focus_score = ?
            WHERE id = ?
        """, (end_time.isoformat(), int(duration), focus_score, session_id))
        await db.commit()
    return {
        "session_id": session_id,
        "start_time": start_time_str,
        "end_time": end_time.isoformat(),
        "duration_seconds": int(duration),
        "focus_score": round(focus_score, 3),
        "total_frames": total_frames,
        "focused_frames": focused_frames,
    }


async def store_focus_event(
    session_id: int,
    is_focused: bool,
    confidence: float,
    metadata: dict,
    db_path: str | None = None,
) -> None:
    """Append one focus event and update session counters."""
    path = db_path or get_db_path()
    async with aiosqlite.connect(path) as db:
        await db.execute("""
            INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
            VALUES (?, ?, ?, ?, ?)
        """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
        await db.execute("""
            UPDATE focus_sessions
            SET total_frames = total_frames + 1,
                focused_frames = focused_frames + ?
            WHERE id = ?
        """, (1 if is_focused else 0, session_id))
        await db.commit()


class EventBuffer:
    """Buffer focus events and flush to DB in batches to avoid per-frame writes."""

    def __init__(self, db_path: str | None = None, flush_interval: float = 2.0):
        self._db_path = db_path or get_db_path()
        self._flush_interval = flush_interval
        self._buf: list = []
        self._lock = asyncio.Lock()
        self._task: asyncio.Task | None = None
        self._total_frames = 0
        self._focused_frames = 0

    def start(self) -> None:
        if self._task is None:
            self._task = asyncio.create_task(self._flush_loop())

    async def stop(self) -> None:
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        await self._flush()

    def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict) -> None:
        self._buf.append((
            session_id,
            datetime.now().isoformat(),
            is_focused,
            confidence,
            json.dumps(metadata),
        ))
        self._total_frames += 1
        if is_focused:
            self._focused_frames += 1

    async def _flush_loop(self) -> None:
        while True:
            await asyncio.sleep(self._flush_interval)
            await self._flush()

    async def _flush(self) -> None:
        async with self._lock:
            if not self._buf:
                return
            batch = self._buf[:]
            total = self._total_frames
            focused = self._focused_frames
            self._buf.clear()
            self._total_frames = 0
            self._focused_frames = 0
            if not batch:
                return
            session_id = batch[0][0]
            try:
                async with aiosqlite.connect(self._db_path) as db:
                    await db.executemany("""
                        INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
                        VALUES (?, ?, ?, ?, ?)
                    """, batch)
                    await db.execute("""
                        UPDATE focus_sessions
                        SET total_frames = total_frames + ?,
                            focused_frames = focused_frames + ?
                        WHERE id = ?
                    """, (total, focused, session_id))
                    await db.commit()
            except Exception as e:
                import logging
                logging.getLogger(__name__).warning("DB flush error: %s", e)
```
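The core of `EventBuffer` is replacing one `INSERT` per video frame with one `executemany` per flush interval. A synchronous stand-in using stdlib `sqlite3` (in place of `aiosqlite`, and with only the `focus_events` columns) shows the same batching pattern:

```python
import sqlite3

# In-memory stand-in for focus_guard.db; focus_events schema as in api/db.py.
conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE focus_events (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        session_id INTEGER NOT NULL,
        timestamp TEXT NOT NULL,
        is_focused BOOLEAN NOT NULL,
        confidence REAL NOT NULL,
        detection_data TEXT
    )
""")

# Events buffered over a flush interval, written in a single batch
# instead of one INSERT (and one commit) per frame.
batch = [
    (1, "2024-01-01T00:00:00", True, 0.91, "{}"),
    (1, "2024-01-01T00:00:01", False, 0.40, "{}"),
    (1, "2024-01-01T00:00:02", True, 0.87, "{}"),
]
conn.executemany(
    "INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data) "
    "VALUES (?, ?, ?, ?, ?)",
    batch,
)
conn.commit()

total = conn.execute("SELECT COUNT(*) FROM focus_events").fetchone()[0]
focused = conn.execute("SELECT COUNT(*) FROM focus_events WHERE is_focused").fetchone()[0]
print(total, focused)  # → 3 2
```

At ~30 fps with a 2 s flush interval this turns roughly 60 commits into one, which is the point of the buffer.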
api/drawing.py
ADDED

@@ -0,0 +1,124 @@
```python
"""Server-side face mesh and HUD drawing for WebRTC/WS video frames."""

from __future__ import annotations

import cv2
import numpy as np

from mediapipe.tasks.python.vision import FaceLandmarksConnections
from models.face_mesh import FaceMeshDetector

_FONT = cv2.FONT_HERSHEY_SIMPLEX
_CYAN = (255, 255, 0)
_GREEN = (0, 255, 0)
_MAGENTA = (255, 0, 255)
_ORANGE = (0, 165, 255)
_RED = (0, 0, 255)
_WHITE = (255, 255, 255)
_LIGHT_GREEN = (144, 238, 144)

_TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
_CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]


def _lm_px(lm: np.ndarray, idx: int, w: int, h: int) -> tuple[int, int]:
    return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))


def _draw_polyline(
    frame: np.ndarray, lm: np.ndarray, indices: list[int], w: int, h: int, color: tuple, thickness: int
) -> None:
    for i in range(len(indices) - 1):
        cv2.line(
            frame,
            _lm_px(lm, indices[i], w, h),
            _lm_px(lm, indices[i + 1], w, h),
            color,
            thickness,
            cv2.LINE_AA,
        )


def draw_face_mesh(frame: np.ndarray, lm: np.ndarray, w: int, h: int) -> None:
    """Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines on frame."""
    overlay = frame.copy()
    for s, e in _TESSELATION_CONNS:
        cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
    cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
    for s, e in _CONTOUR_CONNS:
        cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
    _draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    _draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    _draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
    _draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
    _draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
    left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
    cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
    right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
    cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
    for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
        for idx in indices:
            cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
    for iris_idx, eye_inner, eye_outer in [
        (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
        (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
    ]:
        iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
        center = iris_pts[0]
        if len(iris_pts) >= 5:
            radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
            radius = max(int(np.mean(radii)), 2)
            cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
        cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
        eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
        eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
        dx, dy = center[0] - eye_cx, center[1] - eye_cy
        cv2.line(
            frame,
            tuple(center),
            (int(center[0] + dx * 3), int(center[1] + dy * 3)),
            _RED,
            1,
            cv2.LINE_AA,
        )


def draw_hud(frame: np.ndarray, result: dict, model_name: str) -> None:
    """Draw status bar and detail overlay (FOCUSED/NOT FOCUSED, conf, s_face, s_eye, MAR, yawn)."""
    h, w = frame.shape[:2]
    is_focused = result["is_focused"]
    status = "FOCUSED" if is_focused else "NOT FOCUSED"
    color = _GREEN if is_focused else _RED
    cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
    cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
    cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)
    conf = result.get("mlp_prob", result.get("raw_score", 0.0))
    mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
    sf, se = result.get("s_face", 0), result.get("s_eye", 0)
    detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
    cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)
    if result.get("yaw") is not None:
        cv2.putText(
            frame,
            f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
            (w - 280, 48),
            _FONT,
            0.4,
            (180, 180, 180),
            1,
            cv2.LINE_AA,
        )
    if result.get("is_yawning"):
        cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)


def get_tesselation_connections() -> list[tuple[int, int]]:
    """Return tessellation edge pairs for client-side face mesh (cached by client)."""
    return list(_TESSELATION_CONNS)
```
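Everything in `api/drawing.py` hinges on `_lm_px`: MediaPipe landmarks arrive as normalized `(x, y)` in `[0, 1]`, and each drawing call scales them to integer pixel coordinates for the current frame size. The conversion in isolation, on two hypothetical landmarks:

```python
import numpy as np

def lm_px(lm: np.ndarray, idx: int, w: int, h: int) -> tuple[int, int]:
    """Normalized landmark (x, y in [0, 1]) -> integer pixel coordinates."""
    return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))

# Two made-up landmarks on a 640x480 frame (the app's default inference size).
lm = np.array([[0.5, 0.5], [0.25, 0.75]])
print(lm_px(lm, 0, 640, 480))  # → (320, 240)
print(lm_px(lm, 1, 640, 480))  # → (160, 360)
```

Because coordinates stay normalized until draw time, the same landmark array works at any resolution the frame is resized to.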
config/__init__.py
ADDED

@@ -0,0 +1,57 @@
```python
"""Load app and model config from YAML. Single source for hyperparameters and tunables."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any

_CONFIG: dict[str, Any] | None = None


def _default_path() -> Path:
    return Path(__file__).resolve().parent / "default.yaml"


def load_config(path: str | Path | None = None) -> dict[str, Any]:
    """Load YAML config. Uses FOCUSGUARD_CONFIG env or config/default.yaml."""
    global _CONFIG
    if _CONFIG is not None:
        return _CONFIG
    import yaml
    p = path or os.environ.get("FOCUSGUARD_CONFIG") or _default_path()
    p = Path(p)
    if not p.is_file():
        _CONFIG = {}
        return _CONFIG
    with open(p, "r", encoding="utf-8") as f:
        _CONFIG = yaml.safe_load(f) or {}
    return _CONFIG


def get(key_path: str, default: Any = None) -> Any:
    """Return a nested config value. E.g. get('app.db_path'), get('mlp.epochs')."""
    cfg = load_config()
    for part in key_path.split("."):
        if not isinstance(cfg, dict) or part not in cfg:
            return default
        cfg = cfg[part]
    return cfg


def flatten_for_clearml(cfg: dict[str, Any] | None = None, prefix: str = "") -> dict[str, Any]:
    """Flatten nested config so every value appears as a ClearML task parameter (no nested dicts)."""
    cfg = cfg if cfg is not None else load_config()
    out = {}
    for k, v in cfg.items():
        key = f"{prefix}/{k}" if prefix else k
        if isinstance(v, dict) and v and not any(isinstance(x, (dict, list)) for x in v.values()):
            for k2, v2 in v.items():
                out[f"{key}/{k2}"] = v2
        elif isinstance(v, dict) and v:
            out.update(flatten_for_clearml(v, key))
        elif isinstance(v, list):
            out[key] = str(v)
        else:
            out[key] = v
    return out
```
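`flatten_for_clearml` turns the nested YAML into `section/key` pairs, stringifying lists, so every hyperparameter shows up as a flat ClearML task parameter. Since the module itself needs PyYAML and the package layout, here is the same flattening logic copied into a standalone function and exercised on a small dict:

```python
from typing import Any

def flatten(cfg: dict[str, Any], prefix: str = "") -> dict[str, Any]:
    """Standalone copy of flatten_for_clearml's logic, for illustration only."""
    out: dict[str, Any] = {}
    for k, v in cfg.items():
        key = f"{prefix}/{k}" if prefix else k
        if isinstance(v, dict) and v and not any(isinstance(x, (dict, list)) for x in v.values()):
            for k2, v2 in v.items():          # shallow dict of scalars: one level of keys
                out[f"{key}/{k2}"] = v2
        elif isinstance(v, dict) and v:
            out.update(flatten(v, key))       # deeper dicts: recurse
        elif isinstance(v, list):
            out[key] = str(v)                 # lists become strings (ClearML-friendly)
        else:
            out[key] = v
    return out

cfg = {"mlp": {"epochs": 30, "hidden_sizes": [64, 32]}, "app": {"db_path": "focus_guard.db"}}
print(flatten(cfg))
# → {'mlp/epochs': 30, 'mlp/hidden_sizes': '[64, 32]', 'app/db_path': 'focus_guard.db'}
```

Note the asymmetry: `app` (all scalars) is flattened by the fast first branch, while `mlp` (contains a list) goes through the recursive branch and the list is stringified.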
config/default.yaml
ADDED

@@ -0,0 +1,79 @@
```yaml
# FocusGuard app and model config. Override with FOCUSGUARD_CONFIG env path if needed.

app:
  db_path: "focus_guard.db"
  inference_size: [640, 480]
  inference_workers: 4
  default_model: "mlp"
  calibration_verify_target: [0.5, 0.5]
  no_face_confidence_cap: 0.1

l2cs_boost:
  base_weight: 0.35
  l2cs_weight: 0.65
  veto_threshold: 0.38
  fused_threshold: 0.52

mlp:
  model_name: "face_orientation"
  epochs: 30
  batch_size: 32
  lr: 0.001
  seed: 42
  split_ratios: [0.7, 0.15, 0.15]
  hidden_sizes: [64, 32]

xgboost:
  n_estimators: 600
  max_depth: 8
  learning_rate: 0.1489
  subsample: 0.9625
  colsample_bytree: 0.9013
  reg_alpha: 1.1407
  reg_lambda: 2.4181
  eval_metric: "logloss"

data:
  split_ratios: [0.7, 0.15, 0.15]
  clip:
    yaw: [-45, 45]
    pitch: [-30, 30]
    roll: [-30, 30]
    ear: [0, 0.85]
    mar: [0, 1.0]
    gaze_offset: [0, 0.50]
    perclos: [0, 0.80]
    blink_rate: [0, 30.0]
    closure_duration: [0, 10.0]
    yawn_duration: [0, 10.0]

pipeline:
  geometric:
    max_angle: 22.0
    alpha: 0.7
    beta: 0.3
    threshold: 0.55
  smoother:
    alpha_up: 0.55
    alpha_down: 0.45
    grace_frames: 10
  hybrid_defaults:
    w_mlp: 0.3
    w_geo: 0.7
    threshold: 0.35
    geo_face_weight: 0.7
    geo_eye_weight: 0.3
    mlp_threshold: 0.23

evaluation:
  seed: 42
  mlp_sklearn:
    hidden_layer_sizes: [64, 32]
    max_iter: 200
    validation_fraction: 0.15
  geo_weights:
    face: 0.7
    eye: 0.3
  threshold_search:
    alphas: [0.2, 0.85]
    w_mlps: [0.3, 0.85]
```
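The `pipeline.hybrid_defaults` block suggests a weighted blend of the MLP probability and the geometric score compared against a threshold. The fusion code itself is not part of this commit, so the following is only a sketch of how those three values would plausibly combine (the function name and linear form are assumptions, not the repo's implementation):

```python
# Assumed linear fusion using pipeline.hybrid_defaults values from the YAML above.
W_MLP, W_GEO, THRESHOLD = 0.3, 0.7, 0.35

def hybrid_focused(p_mlp: float, s_geo: float) -> bool:
    """Blend MLP probability and geometric score; compare against the tuned cutoff."""
    score = W_MLP * p_mlp + W_GEO * s_geo
    return score >= THRESHOLD

print(hybrid_focused(0.9, 0.6))  # 0.3*0.9 + 0.7*0.6 = 0.69 → True
print(hybrid_focused(0.2, 0.1))  # 0.3*0.2 + 0.7*0.1 = 0.13 → False
```

The `w_mlp = 0.3` default matches the best hybrid MLP weight reported by `justify_thresholds` in the README.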
data_preparation/prepare_dataset.py
CHANGED

@@ -1,3 +1,11 @@
+"""
+Single source for pooled train/val/test data and splits.
+
+- Data: load_all_pooled() / load_per_person() from data/collected_*/*.npz (same pattern everywhere).
+- Splits: get_numpy_splits() / get_dataloaders() use stratified train/val/test with a fixed seed from config.
+- Test is held out before any preprocessing; StandardScaler is fit on train only, then applied to val and test.
+"""
+
 import os
 import glob
 
@@ -9,6 +17,10 @@ torch = None
 Dataset = object  # type: ignore
 DataLoader = None
 
+# Defaults for stratified split (overridden by config when available)
+_DEFAULT_SPLIT_RATIOS = (0.7, 0.15, 0.15)
+_DEFAULT_SPLIT_SEED = 42
+
 
 def _require_torch():
     global torch, Dataset, DataLoader
@@ -90,9 +102,10 @@ def load_all_pooled(model_name: str = "face_orientation", data_dir: str = None):
     npz_files = sorted(glob.glob(pattern))
 
     if not npz_files:
+        raise FileNotFoundError(
+            f"No .npz files matching {pattern}. "
+            "Collect data first with `python -m models.collect_features --name <name>`."
+        )
 
     all_X, all_y = [], []
     all_names = None
@@ -178,8 +191,23 @@ def _generate_synthetic_data(model_name):
     return features, labels
 
 
+def get_default_split_config():
+    """Return (split_ratios, seed) from config so all scripts use the same split. Reproducible and consistent."""
+    try:
+        from config import get
+        data = get("data") or {}
+        ratios = data.get("split_ratios", list(_DEFAULT_SPLIT_RATIOS))
+        seed = get("mlp.seed") or _DEFAULT_SPLIT_SEED
+        return (tuple(ratios), int(seed))
+    except Exception:
+        return (_DEFAULT_SPLIT_RATIOS, _DEFAULT_SPLIT_SEED)
+
+
 def _split_and_scale(features, labels, split_ratios, seed, scale):
-    """
+    """Stratified train/val/test split. Test is held out first; val is split from the rest.
     test_ratio = split_ratios[2]
     val_ratio = split_ratios[1] / (split_ratios[0] + split_ratios[1])
@@ -196,7 +224,7 @@ def _split_and_scale(features, labels, split_ratios, seed, scale):
     X_train = scaler.fit_transform(X_train)
     X_val = scaler.transform(X_val)
     X_test = scaler.transform(X_test)
-    print("[DATA] Applied StandardScaler (fitted on training split)")
 
     splits = {
         "X_train": X_train, "y_train": y_train,
@@ -208,8 +236,13 @@
     return splits, scaler
 
 
-def get_numpy_splits(model_name: str, split_ratios=
-    """Return
     features, labels = _load_real_data(model_name)
     num_features = features.shape[1]
     num_classes = int(labels.max()) + 1
@@ -219,8 +252,13 @@ def get_numpy_splits(model_name: str, split_ratios=(0.7, 0.15, 0.15), seed: int
     return splits, num_features, num_classes, scaler
 
 
-def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=
-    """Return PyTorch DataLoaders for
     _, _, dataloader_cls = _require_torch()
     features, labels = _load_real_data(model_name)
     num_features = features.shape[1]
|
| 208 |
+
No training data is used for validation or test. Scaler is fit on train only, then
|
| 209 |
+
applied to val and test (no leakage from val/test into scaling).
|
| 210 |
+
"""
|
| 211 |
test_ratio = split_ratios[2]
|
| 212 |
val_ratio = split_ratios[1] / (split_ratios[0] + split_ratios[1])
|
| 213 |
|
|
|
|
| 224 |
X_train = scaler.fit_transform(X_train)
|
| 225 |
X_val = scaler.transform(X_val)
|
| 226 |
X_test = scaler.transform(X_test)
|
| 227 |
+
print("[DATA] Applied StandardScaler (fitted on training split only)")
|
| 228 |
|
| 229 |
splits = {
|
| 230 |
"X_train": X_train, "y_train": y_train,
|
|
|
|
| 236 |
return splits, scaler
|
| 237 |
|
| 238 |
|
| 239 |
+
def get_numpy_splits(model_name: str, split_ratios=None, seed=None, scale: bool = True):
|
| 240 |
+
"""Return train/val/test numpy arrays. Uses config defaults for split_ratios/seed when None.
|
| 241 |
+
Same dataset and split logic as get_dataloaders for consistent evaluation."""
|
| 242 |
+
if split_ratios is None or seed is None:
|
| 243 |
+
_ratios, _seed = get_default_split_config()
|
| 244 |
+
split_ratios = split_ratios if split_ratios is not None else _ratios
|
| 245 |
+
seed = seed if seed is not None else _seed
|
| 246 |
features, labels = _load_real_data(model_name)
|
| 247 |
num_features = features.shape[1]
|
| 248 |
num_classes = int(labels.max()) + 1
|
|
|
|
| 252 |
return splits, num_features, num_classes, scaler
|
| 253 |
|
| 254 |
|
| 255 |
+
def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=None, seed=None, scale: bool = True):
|
| 256 |
+
"""Return PyTorch DataLoaders. Uses config defaults for split_ratios/seed when None.
|
| 257 |
+
Test set is held out before preprocessing; scaler fit on train only."""
|
| 258 |
+
if split_ratios is None or seed is None:
|
| 259 |
+
_ratios, _seed = get_default_split_config()
|
| 260 |
+
split_ratios = split_ratios if split_ratios is not None else _ratios
|
| 261 |
+
seed = seed if seed is not None else _seed
|
| 262 |
_, _, dataloader_cls = _require_torch()
|
| 263 |
features, labels = _load_real_data(model_name)
|
| 264 |
num_features = features.shape[1]
|
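The leak-free split discipline the new docstring describes (hold out test first, carve val from the remainder, fit the scaler on train only) can be sketched in isolation. `split_and_scale` below is an illustrative stand-alone version on random data, not the project's `_split_and_scale`:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def split_and_scale(X, y, ratios=(0.7, 0.15, 0.15), seed=42):
    """Hold out test first, then carve val from the remainder; fit the scaler on train only."""
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=ratios[2], stratify=y, random_state=seed)
    # Validation fraction relative to what is left after the test hold-out.
    val_ratio = ratios[1] / (ratios[0] + ratios[1])
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=val_ratio, stratify=y_rest, random_state=seed)
    scaler = StandardScaler().fit(X_train)  # no val/test leakage into scaling stats
    return (scaler.transform(X_train), y_train,
            scaler.transform(X_val), y_val,
            scaler.transform(X_test), y_test)

X = np.random.RandomState(0).randn(200, 10)
y = np.arange(200) % 2
X_tr, y_tr, X_va, y_va, X_te, y_te = split_and_scale(X, y)
print(X_tr.shape, X_va.shape, X_te.shape)
```

Because the scaler never sees val or test rows, its mean/variance come from the training split alone, which is exactly the property the diff enforces.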
evaluation/GROUPED_SPLIT_BENCHMARK.md
ADDED

@@ -0,0 +1,13 @@
+# Grouped vs pooled split benchmark
+
+This compares the same XGBoost config under two evaluation protocols.
+
+Config: `{'n_estimators': 600, 'max_depth': 8, 'learning_rate': 0.1489, 'subsample': 0.9625, 'colsample_bytree': 0.9013, 'reg_alpha': 1.1407, 'reg_lambda': 2.4181, 'eval_metric': 'logloss'}`
+Quick mode: yes (n_estimators=200)
+
+| Protocol | Accuracy | F1 (weighted) | ROC-AUC |
+|----------|---------:|--------------:|--------:|
+| Pooled random split (70/15/15) | 0.9510 | 0.9507 | 0.9869 |
+| Grouped LOPO (9 folds) | 0.8303 | 0.8304 | 0.8801 |
+
+Use grouped LOPO as the primary generalisation metric when reporting model quality.
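The grouped protocol in the table (each fold holds out one participant entirely) can be sketched with scikit-learn's `LeaveOneGroupOut`; the data and the logistic-regression stand-in below are synthetic, not the project's features or XGBoost config:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in: 9 "participants" with 100 frames each.
rng = np.random.default_rng(0)
X = rng.normal(size=(900, 10))
y = (X[:, 0] + 0.5 * rng.normal(size=900) > 0).astype(int)
groups = np.repeat(np.arange(9), 100)

f1s = []
for train_idx, test_idx in LeaveOneGroupOut().split(X, y, groups):
    # Scaler and model see only the 8 training participants.
    scaler = StandardScaler().fit(X[train_idx])
    clf = LogisticRegression(max_iter=1000).fit(
        scaler.transform(X[train_idx]), y[train_idx])
    pred = clf.predict(scaler.transform(X[test_idx]))
    f1s.append(f1_score(y[test_idx], pred, average="weighted"))

print(f"grouped LOPO mean F1 over {len(f1s)} folds: {np.mean(f1s):.3f}")
```

One fold per group is what makes this a person-generalisation estimate: no frames from the held-out participant ever reach the scaler or the model, unlike the pooled random split.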
evaluation/README.md
CHANGED

@@ -14,6 +14,9 @@ python -m evaluation.justify_thresholds

 (LOPO over 9 participants, Youden’s J, weight grid search; ~10–15 min.) Outputs go to `plots/` and the markdown file.

-**Feature importance:** Run `python -m evaluation.feature_importance` for XGBoost gain …
+**Feature importance:** Run `python -m evaluation.feature_importance` for full XGBoost gain + leave-one-feature-out LOPO (slow).
+Fast iteration mode: `python -m evaluation.feature_importance --quick --skip-lofo` (channel ablation + gain only).
+
+**Grouped benchmark:** Run `python -m evaluation.grouped_split_benchmark` for full run, or `python -m evaluation.grouped_split_benchmark --quick` for faster approximate numbers.

 **Who writes here:** `models.mlp.train`, `models.xgboost.train`, `evaluation.justify_thresholds`, `evaluation.feature_importance`, and the notebooks.
evaluation/THRESHOLD_JUSTIFICATION.md
CHANGED

@@ -2,105 +2,31 @@

 Auto-generated by `evaluation/justify_thresholds.py` using LOPO cross-validation over 9 participants (~145k samples).

+## 0. Latest random split checkpoints (15% test split)
+
+From the latest training runs:
+
+| Model | Accuracy | F1 | ROC-AUC |
+|-------|----------|-----|---------|
+| XGBoost | 95.87% | 0.9585 | 0.9908 |
+| MLP | 92.92% | 0.9287 | 0.9714 |
+
 ## 1. ML Model Decision Thresholds

+XGBoost config used for this report: `{'n_estimators': 600, 'max_depth': 8, 'learning_rate': 0.1489, 'subsample': 0.9625, 'colsample_bytree': 0.9013, 'reg_alpha': 1.1407, 'reg_lambda': 2.4181, 'eval_metric': 'logloss'}`.
+
 Thresholds selected via **Youden's J statistic** (J = sensitivity + specificity - 1) on pooled LOPO held-out predictions.

 | Model | LOPO AUC | Optimal Threshold (Youden's J) | F1 @ Optimal | F1 @ 0.50 |
 |-------|----------|-------------------------------|--------------|-----------|
 | MLP | 0.8624 | **0.228** | 0.8578 | 0.8149 |
-| XGBoost | 0.…
+| XGBoost | 0.8695 | **0.280** | 0.8549 | 0.8324 |

 ![…](…)

 ![…](…)

-## 2. …
-
-At the optimal threshold (Youden's J), pooled over all LOPO held-out predictions:
-
-| Model | Threshold | Precision | Recall | F1 | Accuracy |
-|-------|----------:|----------:|-------:|---:|---------:|
-| MLP | 0.228 | 0.8187 | 0.9008 | 0.8578 | 0.8164 |
-| XGBoost | 0.377 | 0.8426 | 0.8750 | 0.8585 | 0.8228 |
-
-Higher threshold → fewer positive predictions → higher precision, lower recall. Youden's J picks the threshold that balances sensitivity and specificity (recall for the positive class and true negative rate).
-
-## 3. Confusion Matrix (Pooled LOPO)
-
-At optimal threshold. Rows = true label, columns = predicted label (0 = unfocused, 1 = focused).
-
-### MLP
-
-| | Pred 0 | Pred 1 |
-|--|-------:|-------:|
-| **True 0** | 38065 (TN) | 17750 (FP) |
-| **True 1** | 8831 (FN) | 80147 (TP) |
-
-TN=38065, FP=17750, FN=8831, TP=80147.
-
-### XGBoost
-
-| | Pred 0 | Pred 1 |
-|--|-------:|-------:|
-| **True 0** | 41271 (TN) | 14544 (FP) |
-| **True 1** | 11118 (FN) | 77860 (TP) |
-
-TN=41271, FP=14544, FN=11118, TP=77860.
-
-![…](…)
-
-![…](…)
-
-## 4. Per-Person Performance Variance (LOPO)
-
-One fold per left-out person; metrics at optimal threshold.
-
-### MLP — per held-out person
-
-| Person | Accuracy | F1 | Precision | Recall |
-|--------|---------:|---:|----------:|-------:|
-| Abdelrahman | 0.8628 | 0.9029 | 0.8760 | 0.9314 |
-| Jarek | 0.8400 | 0.8770 | 0.8909 | 0.8635 |
-| Junhao | 0.8872 | 0.8986 | 0.8354 | 0.9723 |
-| Kexin | 0.7941 | 0.8123 | 0.7965 | 0.8288 |
-| Langyuan | 0.5877 | 0.6169 | 0.4972 | 0.8126 |
-| Mohamed | 0.8432 | 0.8653 | 0.7931 | 0.9519 |
-| Yingtao | 0.8794 | 0.9263 | 0.9217 | 0.9309 |
-| ayten | 0.8307 | 0.8986 | 0.8558 | 0.9459 |
-| saba | 0.9192 | 0.9243 | 0.9260 | 0.9226 |
-
-### XGBoost — per held-out person
-
-| Person | Accuracy | F1 | Precision | Recall |
-|--------|---------:|---:|----------:|-------:|
-| Abdelrahman | 0.8601 | 0.8959 | 0.9129 | 0.8795 |
-| Jarek | 0.8680 | 0.8993 | 0.9070 | 0.8917 |
-| Junhao | 0.9099 | 0.9180 | 0.8627 | 0.9810 |
-| Kexin | 0.7363 | 0.7385 | 0.7906 | 0.6928 |
-| Langyuan | 0.6738 | 0.6945 | 0.5625 | 0.9074 |
-| Mohamed | 0.8868 | 0.8988 | 0.8529 | 0.9498 |
-| Yingtao | 0.8711 | 0.9195 | 0.9347 | 0.9048 |
-| ayten | 0.8451 | 0.9070 | 0.8654 | 0.9528 |
-| saba | 0.9393 | 0.9421 | 0.9615 | 0.9235 |
-
-### Summary across persons
-
-| Model | Accuracy mean ± std | F1 mean ± std | Precision mean ± std | Recall mean ± std |
-|-------|---------------------|---------------|----------------------|-------------------|
-| MLP | 0.8271 ± 0.0968 | 0.8580 ± 0.0968 | 0.8214 ± 0.1307 | 0.9067 ± 0.0572 |
-| XGBoost | 0.8434 ± 0.0847 | 0.8682 ± 0.0879 | 0.8500 ± 0.1191 | 0.8981 ± 0.0836 |
-
-## 5. Confidence Intervals (95%, LOPO over 9 persons)
-
-Mean ± half-width of 95% t-interval (df=8) for each metric across the 9 left-out persons.
-
-| Model | F1 | Accuracy | Precision | Recall |
-|-------|---:|--------:|----------:|-------:|
-| MLP | 0.8580 [0.7835, 0.9326] | 0.8271 [0.7526, 0.9017] | 0.8214 [0.7207, 0.9221] | 0.9067 [0.8626, 0.9507] |
-| XGBoost | 0.8682 [0.8005, 0.9358] | 0.8434 [0.7781, 0.9086] | 0.8500 [0.7583, 0.9417] | 0.8981 [0.8338, 0.9625] |
-
-## 6. Geometric Pipeline Weights (s_face vs s_eye)
+## 2. Geometric Pipeline Weights (s_face vs s_eye)

 Grid search over face weight alpha in {0.2 ... 0.8}. Eye weight = 1 - alpha. Threshold per fold via Youden's J.

@@ -118,9 +44,9 @@ Grid search over face weight alpha in {0.2 ... 0.8}. Eye weight = 1 - alpha. Thr…

 ![…](…)

-## …
+## 3. Hybrid Pipeline Weights (MLP vs Geometric)

-Grid search over w_mlp in {0.3 ... 0.8}. w_geo = 1 - w_mlp. Geometric sub-score uses same weights as geometric pipeline (face=0.7, eye=0.3).
+Grid search over w_mlp in {0.3 ... 0.8}. w_geo = 1 - w_mlp. Geometric sub-score uses same weights as geometric pipeline (face=0.7, eye=0.3). If you change geometric weights, re-run this script — optimal w_mlp can shift.

 | MLP Weight (w_mlp) | Mean LOPO F1 |
 |-------------------:|-------------:|

@@ -131,43 +57,11 @@ Grid search over w_mlp in {0.3 ... 0.8}. w_geo = 1 - w_mlp. Geometric sub-score …
 | 0.7 | 0.8039 |
 | 0.8 | 0.8016 |

-**Best:** w_mlp = 0.3 (MLP 30%, geometric 70%)
-
-![…](…)
-
-## 8. Hybrid Pipeline: XGBoost vs Geometric
-
-Same grid over w_xgb in {0.3 ... 0.8}. w_geo = 1 - w_xgb.
-
-| XGBoost Weight (w_xgb) | Mean LOPO F1 |
-|-----------------------:|-------------:|
-| 0.3 | 0.8639 **<-- selected** |
-| 0.4 | 0.8552 |
-| 0.5 | 0.8451 |
-| 0.6 | 0.8419 |
-| 0.7 | 0.8382 |
-| 0.8 | 0.8353 |
-
-**Best:** w_xgb = 0.3 → mean LOPO F1 = 0.8639
-
-![…](…)
-
-### Which hybrid is used in the app?
-
-**XGBoost hybrid is better** (F1 = 0.8639 vs MLP hybrid F1 = 0.8409).
-
-### Logistic regression combiner (replaces heuristic weights)
-
-Instead of a fixed linear blend (e.g. 0.3·ML + 0.7·geo), a **logistic regression** combines model probability and geometric score: meta-features = [model_prob, geo_score], trained on the same LOPO splits. Threshold from Youden's J on combiner output.
-
-| Method | Mean LOPO F1 |
-|--------|-------------:|
-| Heuristic weight grid (best w) | 0.8639 |
-| **LR combiner** | **0.8241** |
-
-## …
+**Best:** w_mlp = 0.3 (MLP 30%, geometric 70%)
+
+![…](…)
+
+## 4. Eye and Mouth Aspect Ratio Thresholds

 ### EAR (Eye Aspect Ratio)

@@ -193,7 +87,7 @@ Between 0.16 and 0.30 the `_ear_score` function linearly interpolates from 0 to …

 ![…](…)

-## …
+## 5. Other Constants

 | Constant | Value | Rationale |
 |----------|------:|-----------|
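The Youden's J selection used throughout the report amounts to maximising tpr - fpr over the candidate thresholds returned by the ROC curve; a minimal sketch with scikit-learn on toy scores (not the report's LOPO predictions):

```python
import numpy as np
from sklearn.metrics import roc_curve

# Toy labels/scores standing in for pooled LOPO held-out predictions.
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

fpr, tpr, thresholds = roc_curve(y_true, y_score)
j = tpr - fpr  # Youden's J = sensitivity + specificity - 1
best_thr = thresholds[int(np.argmax(j))]
print(f"optimal threshold by Youden's J: {best_thr}")
```

On real, imbalanced probabilities this is what pushes the operating point well below 0.5 (0.228 for the MLP, 0.280 for XGBoost in the tables above).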
evaluation/feature_importance.py
CHANGED

@@ -10,6 +10,7 @@ Outputs:

 import os
 import sys
+import argparse

 import numpy as np
 from sklearn.preprocessing import StandardScaler

@@ -20,9 +21,10 @@
 if _PROJECT_ROOT not in sys.path:
     sys.path.insert(0, _PROJECT_ROOT)

-from data_preparation.prepare_dataset import load_per_person, SELECTED_FEATURES
+from data_preparation.prepare_dataset import get_default_split_config, load_per_person, SELECTED_FEATURES
+from models.xgboost.config import XGB_BASE_PARAMS, build_xgb_classifier, get_xgb_params

-SEED = …
+_, SEED = get_default_split_config()
 FEATURES = SELECTED_FEATURES["face_orientation"]

@@ -45,14 +47,22 @@ def xgb_feature_importance():
     return dict(zip(FEATURES, order))


-def …
-
-
-
-
+def _make_eval_model(seed: int, quick: bool):
+    if not quick:
+        return build_xgb_classifier(seed, verbosity=0)
+
+    params = get_xgb_params()
+    params["n_estimators"] = 200
+    params["random_state"] = seed
+    params["verbosity"] = 0
+    return XGBClassifier(**params)
+
+
+def run_ablation_lopo(by_person, persons, quick: bool):
+    """Leave-one-feature-out: for each feature, train XGBoost on the other 9 with LOPO, report mean F1."""
     results = {}
     for drop_feat in FEATURES:
+        print(f" -> dropping {drop_feat} ({len(results)+1}/{len(FEATURES)})")
         idx_keep = [i for i, f in enumerate(FEATURES) if f != drop_feat]
         f1s = []
         for held_out in persons:

@@ -66,13 +76,7 @@ def run_ablation_lopo():
             X_tr_sc = scaler.transform(X_tr)
             X_te_sc = scaler.transform(X_te)

-            xgb = XGBClassifier(
-                n_estimators=600, max_depth=8, learning_rate=0.05,
-                subsample=0.8, colsample_bytree=0.8,
-                reg_alpha=0.1, reg_lambda=1.0,
-                eval_metric="logloss",
-                random_state=SEED, verbosity=0,
-            )
+            xgb = _make_eval_model(SEED, quick)
             xgb.fit(X_tr_sc, train_y)
             pred = xgb.predict(X_te_sc)
             f1s.append(f1_score(y_test, pred, average="weighted"))

@@ -80,10 +84,8 @@ def run_ablation_lopo():
     return results


-def run_baseline_lopo_f1():
+def run_baseline_lopo_f1(by_person, persons, quick: bool):
     """Full 10-feature LOPO mean F1 for reference."""
-    by_person, _, _ = load_per_person("face_orientation")
-    persons = sorted(by_person.keys())
     f1s = []
     for held_out in persons:
         train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])

@@ -92,13 +94,7 @@ def run_baseline_lopo_f1():
         scaler = StandardScaler().fit(train_X)
         X_tr_sc = scaler.transform(train_X)
         X_te_sc = scaler.transform(X_test)
-        xgb = XGBClassifier(
-            n_estimators=600, max_depth=8, learning_rate=0.05,
-            subsample=0.8, colsample_bytree=0.8,
-            reg_alpha=0.1, reg_lambda=1.0,
-            eval_metric="logloss",
-            random_state=SEED, verbosity=0,
-        )
+        xgb = _make_eval_model(SEED, quick)
         xgb.fit(X_tr_sc, train_y)
         pred = xgb.predict(X_te_sc)
         f1s.append(f1_score(y_test, pred, average="weighted"))

@@ -113,12 +109,11 @@ CHANNEL_SUBSETS = {
 }


-def run_channel_ablation():
+def run_channel_ablation(by_person, persons, quick: bool, baseline: float):
     """LOPO XGBoost with head-only, eye-only, gaze-only, and all 10. Returns dict subset_name -> mean F1."""
-    by_person, _, _ = load_per_person("face_orientation")
-    persons = sorted(by_person.keys())
     results = {}
     for subset_name, feat_list in CHANNEL_SUBSETS.items():
+        print(f" -> channel {subset_name}")
         idx_keep = [FEATURES.index(f) for f in feat_list]
         f1s = []
         for held_out in persons:

@@ -130,24 +125,40 @@ def run_channel_ablation():
             scaler = StandardScaler().fit(X_tr)
             X_tr_sc = scaler.transform(X_tr)
             X_te_sc = scaler.transform(X_te)
-            xgb = XGBClassifier(
-                n_estimators=600, max_depth=8, learning_rate=0.05,
-                subsample=0.8, colsample_bytree=0.8,
-                reg_alpha=0.1, reg_lambda=1.0,
-                eval_metric="logloss",
-                random_state=SEED, verbosity=0,
-            )
+            xgb = _make_eval_model(SEED, quick)
             xgb.fit(X_tr_sc, train_y)
             pred = xgb.predict(X_te_sc)
             f1s.append(f1_score(y_test, pred, average="weighted"))
         results[subset_name] = np.mean(f1s)
-    baseline = run_baseline_lopo_f1()
     results["all_10"] = baseline
     return results


+def _parse_args():
+    parser = argparse.ArgumentParser(description="Feature importance + LOPO ablation")
+    parser.add_argument(
+        "--quick",
+        action="store_true",
+        help="Use fewer trees (200) for faster iteration.",
+    )
+    parser.add_argument(
+        "--skip-lofo",
+        action="store_true",
+        help="Skip leave-one-feature-out ablation.",
+    )
+    parser.add_argument(
+        "--skip-channel",
+        action="store_true",
+        help="Skip channel ablation.",
+    )
+    return parser.parse_args()
+
+
 def main():
+    args = _parse_args()
     print("=== Feature importance (XGBoost gain) ===")
+    if args.quick:
+        print("Running in quick mode (n_estimators=200).")
     imp = xgb_feature_importance()
     if imp:
         for name in FEATURES:

@@ -155,20 +166,37 @@ def main():
         order = sorted(imp.items(), key=lambda x: -x[1])
         print("  Top-5 by gain:", [x[0] for x in order[:5]])

-    print("\n…
-
+    print("\n[DATA] Loading per-person splits once...")
+    by_person, _, _ = load_per_person("face_orientation")
+    persons = sorted(by_person.keys())
+
+    print("\n=== Baseline LOPO (all 10 features) ===")
+    baseline = run_baseline_lopo_f1(by_person, persons, quick=args.quick)
     print(f"  Baseline (all 10 features) mean LOPO F1: {baseline:.4f}")
-
-
-
-
-
-
-
-
-
-
+
+    ablation = None
+    worst_drop = None
+    if args.skip_lofo:
+        print("\n=== Leave-one-feature-out ablation (LOPO mean F1) ===")
+        print("  skipped (--skip-lofo)")
+    else:
+        print("\n=== Leave-one-feature-out ablation (LOPO mean F1) ===")
+        ablation = run_ablation_lopo(by_person, persons, quick=args.quick)
+        for feat in FEATURES:
+            delta = baseline - ablation[feat]
+            print(f" drop {feat}: F1={ablation[feat]:.4f} (Δ={delta:+.4f})")
+        worst_drop = min(ablation.items(), key=lambda x: x[1])
+        print(f" Largest F1 drop when dropping: {worst_drop[0]} (F1={worst_drop[1]:.4f})")
+
+    channel_f1 = None
+    if args.skip_channel:
+        print("\n=== Channel ablation (LOPO mean F1) ===")
+        print("  skipped (--skip-channel)")
+    else:
+        print("\n=== Channel ablation (LOPO mean F1) ===")
+        channel_f1 = run_channel_ablation(by_person, persons, quick=args.quick, baseline=baseline)
+        for name, f1 in channel_f1.items():
+            print(f" {name}: {f1:.4f}")

     out_dir = os.path.join(_PROJECT_ROOT, "evaluation")
     out_path = os.path.join(out_dir, "feature_selection_justification.md")

@@ -188,6 +216,9 @@ def main():
         "",
         "## 2. XGBoost feature importance (gain)",
         "",
+        f"Config used: `{XGB_BASE_PARAMS}`.",
+        "Quick mode: " + ("yes (200 trees)" if args.quick else "no (full config)"),
+        "",
         "From the trained XGBoost checkpoint (gain on the 10 features):",
         "",
         "| Feature | Gain |",

@@ -207,19 +238,37 @@ def main():
         "",
         f"Baseline (all 10 features) mean LOPO F1: **{baseline:.4f}**.",
         "",
-        "| Feature dropped | Mean LOPO F1 | Δ vs baseline |",
-        "|------------------|--------------|---------------|",
     ])
-
-
-
+    if ablation is None:
+        lines.append("Skipped in this run (`--skip-lofo`).")
+    else:
+        lines.extend([
+            "| Feature dropped | Mean LOPO F1 | Δ vs baseline |",
+            "|------------------|--------------|---------------|",
+        ])
+        for feat in FEATURES:
+            delta = baseline - ablation[feat]
+            lines.append(f"| {feat} | {ablation[feat]:.4f} | {delta:+.4f} |")
+        lines.append("")
+        lines.append(f"Dropping **{worst_drop[0]}** hurts most (F1={worst_drop[1]:.4f}), consistent with it being important.")
+
     lines.append("")
-    lines.append(…
+    lines.append("## 4. Channel ablation (LOPO)")
     lines.append("")
-
+    if channel_f1 is None:
+        lines.append("Skipped in this run (`--skip-channel`).")
+    else:
+        lines.append("| Subset | Mean LOPO F1 |")
+        lines.append("|--------|--------------|")
+        for name in ["head_pose", "eye_state", "gaze", "all_10"]:
+            lines.append(f"| {name} | {channel_f1[name]:.4f} |")
+    lines.append("")
+    lines.append("## 5. Conclusion")
     lines.append("")
-
+    if ablation is None:
+        lines.append("Selection is supported by (1) domain rationale (three attention channels), (2) XGBoost gain importance, and (3) channel ablation. Run without `--skip-lofo` for full leave-one-out ablation.")
+    else:
+        lines.append("Selection is supported by (1) domain rationale (three attention channels), (2) XGBoost gain importance, and (3) leave-one-out ablation. SHAP or correlation-based pruning can be added in future work.")
     lines.append("")
     with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
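The leave-one-feature-out loop in this script can be illustrated end-to-end on synthetic data; here `LogisticRegression` stands in for the XGBoost model and the feature names are made up:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

FEATURES = [f"f{i}" for i in range(5)]  # made-up feature names
rng = np.random.default_rng(1)
X = rng.normal(size=(600, 5))
# Only f0 (strongly) and f1 (weakly) carry signal; f2..f4 are noise.
y = (2 * X[:, 0] + X[:, 1] + rng.normal(size=600) > 0).astype(int)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)

def f1_with(idx_keep):
    clf = LogisticRegression(max_iter=1000).fit(X_tr[:, idx_keep], y_tr)
    return f1_score(y_te, clf.predict(X_te[:, idx_keep]), average="weighted")

baseline = f1_with(list(range(len(FEATURES))))
drops = {}
for i, name in enumerate(FEATURES):
    kept = [j for j in range(len(FEATURES)) if j != i]
    drops[name] = baseline - f1_with(kept)  # positive delta = feature carried signal

print({k: round(v, 4) for k, v in drops.items()})
```

The script applies the same idea with one LOPO sweep per dropped feature, which is why the full run is slow and `--skip-lofo` exists.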
evaluation/feature_selection_justification.md
CHANGED

```diff
@@ -13,6 +13,9 @@ Excluded: v_gaze (noisy), mar (rare events), yaw/roll (redundant with head_deviation)
 
 ## 2. XGBoost feature importance (gain)
 
+Config used: `{'n_estimators': 600, 'max_depth': 8, 'learning_rate': 0.1489, 'subsample': 0.9625, 'colsample_bytree': 0.9013, 'reg_alpha': 1.1407, 'reg_lambda': 2.4181, 'eval_metric': 'logloss'}`.
+Quick mode: yes (200 trees)
+
 From the trained XGBoost checkpoint (gain on the 10 features):
 
 | Feature | Gain |
@@ -32,23 +35,19 @@ From the trained XGBoost checkpoint (gain on the 10 features):
 
 ## 3. Leave-one-feature-out ablation (LOPO)
 
-Baseline (all 10 features) mean LOPO F1: **0.…**
-
-|------------------|--------------|---------------|
-| head_deviation | 0.8395 | -0.0068 |
-| s_face | 0.8390 | -0.0063 |
-| s_eye | 0.8342 | -0.0015 |
-| h_gaze | 0.8244 | +0.0083 |
-| pitch | 0.8250 | +0.0077 |
-| ear_left | 0.8326 | +0.0001 |
-| ear_avg | 0.8350 | -0.0023 |
-| ear_right | 0.8344 | -0.0017 |
-| gaze_offset | 0.8351 | -0.0024 |
-| perclos | 0.8258 | +0.0069 |
-
-## …
-
-Selection is supported by (1) domain rationale (three attention channels), (2) XGBoost gain importance, and (3) …
+Baseline (all 10 features) mean LOPO F1: **0.8286**.
+
+Skipped in this run (`--skip-lofo`).
+
+## 4. Channel ablation (LOPO)
+
+| Subset | Mean LOPO F1 |
+|--------|--------------|
+| head_pose | 0.7480 |
+| eye_state | 0.8071 |
+| gaze | 0.7260 |
+| all_10 | 0.8286 |
+
+## 5. Conclusion
+
+Selection is supported by (1) domain rationale (three attention channels), (2) XGBoost gain importance, and (3) channel ablation. Run without `--skip-lofo` for full leave-one-out ablation.
```
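Section 3's leave-one-feature-out table (skipped in this run) is produced by retraining with each feature dropped and comparing F1 against the all-features baseline. A minimal sketch of that loop on synthetic data — the logistic-regression classifier, feature count, and dataset here are stand-ins for illustration, not the project's actual XGBoost/LOPO pipeline:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Toy data: 10 features, feature 0 carries all of the signal.
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 10))
y = (X[:, 0] > 0).astype(int)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

def f1_with(cols):
    """Train on a subset of feature columns and return test F1."""
    clf = LogisticRegression(max_iter=1000).fit(X_tr[:, cols], y_tr)
    return f1_score(y_te, clf.predict(X_te[:, cols]))

baseline = f1_with(list(range(10)))
# Delta < 0 means removing the feature hurts; the most negative delta
# flags the most important feature.
deltas = {i: f1_with([j for j in range(10) if j != i]) - baseline for i in range(10)}
worst = min(deltas, key=deltas.get)
```

In the report above the same deltas are computed per LOPO fold with the project's XGBoost model, then averaged.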
evaluation/grouped_split_benchmark.py
ADDED

```python
"""Compare pooled random split vs grouped LOPO for XGBoost."""

import os
import sys

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from data_preparation.prepare_dataset import get_default_split_config, get_numpy_splits, load_per_person
# XGB_BASE_PARAMS is referenced in write_report(), so it must be imported here.
from models.xgboost.config import XGB_BASE_PARAMS, build_xgb_classifier

MODEL_NAME = "face_orientation"
OUT_PATH = os.path.join(_PROJECT_ROOT, "evaluation", "GROUPED_SPLIT_BENCHMARK.md")


def run_pooled_split():
    split_ratios, seed = get_default_split_config()
    splits, _, _, _ = get_numpy_splits(
        model_name=MODEL_NAME,
        split_ratios=split_ratios,
        seed=seed,
        scale=False,
    )
    model = build_xgb_classifier(seed, verbosity=0, early_stopping_rounds=30)
    model.fit(
        splits["X_train"],
        splits["y_train"],
        eval_set=[(splits["X_val"], splits["y_val"])],
        verbose=False,
    )
    probs = model.predict_proba(splits["X_test"])[:, 1]
    preds = (probs >= 0.5).astype(int)
    y = splits["y_test"]
    return {
        "accuracy": float(accuracy_score(y, preds)),
        "f1": float(f1_score(y, preds, average="weighted")),
        "auc": float(roc_auc_score(y, probs)),
    }


def run_grouped_lopo():
    by_person, _, _ = load_per_person(MODEL_NAME)
    persons = sorted(by_person.keys())
    scores = {"accuracy": [], "f1": [], "auc": []}

    _, seed = get_default_split_config()
    for held_out in persons:
        train_x = np.concatenate([by_person[p][0] for p in persons if p != held_out], axis=0)
        train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out], axis=0)
        test_x, test_y = by_person[held_out]

        model = build_xgb_classifier(seed, verbosity=0)
        model.fit(train_x, train_y, verbose=False)
        probs = model.predict_proba(test_x)[:, 1]
        preds = (probs >= 0.5).astype(int)

        scores["accuracy"].append(float(accuracy_score(test_y, preds)))
        scores["f1"].append(float(f1_score(test_y, preds, average="weighted")))
        scores["auc"].append(float(roc_auc_score(test_y, probs)))

    return {
        "accuracy": float(np.mean(scores["accuracy"])),
        "f1": float(np.mean(scores["f1"])),
        "auc": float(np.mean(scores["auc"])),
        "folds": len(persons),
    }


def write_report(pooled, grouped):
    lines = [
        "# Grouped vs pooled split benchmark",
        "",
        "This compares the same XGBoost config under two evaluation protocols.",
        "",
        f"Config: `{XGB_BASE_PARAMS}`",
        "",
        "| Protocol | Accuracy | F1 (weighted) | ROC-AUC |",
        "|----------|---------:|--------------:|--------:|",
        f"| Pooled random split (70/15/15) | {pooled['accuracy']:.4f} | {pooled['f1']:.4f} | {pooled['auc']:.4f} |",
        f"| Grouped LOPO ({grouped['folds']} folds) | {grouped['accuracy']:.4f} | {grouped['f1']:.4f} | {grouped['auc']:.4f} |",
        "",
        "Use grouped LOPO as the primary generalisation metric when reporting model quality.",
        "",
    ]

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f"[LOG] Wrote {OUT_PATH}")


def main():
    pooled = run_pooled_split()
    grouped = run_grouped_lopo()
    write_report(pooled, grouped)
    print(
        "[DONE] pooled_f1={:.4f} grouped_f1={:.4f}".format(
            pooled["f1"], grouped["f1"]
        )
    )


if __name__ == "__main__":
    main()
```
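The grouped LOPO loop in `run_grouped_lopo` is the standard leave-one-group-out protocol, which scikit-learn also ships as `LeaveOneGroupOut`. A small sketch with synthetic "persons" (the data and group names here are made up, not the project's dataset):

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

# Synthetic stand-in: 3 persons, 4 samples each.
rng = np.random.default_rng(0)
X = rng.normal(size=(12, 5))
y = rng.integers(0, 2, size=12)
groups = np.repeat(["p1", "p2", "p3"], 4)

logo = LeaveOneGroupOut()
folds = list(logo.split(X, y, groups=groups))

# One fold per person; each test fold contains only that person's samples,
# which is what prevents identity leakage between train and test.
for train_idx, test_idx in folds:
    assert len(set(groups[test_idx])) == 1
    assert set(groups[test_idx]).isdisjoint(set(groups[train_idx]))
```

This is why the pooled random split tends to report optimistic numbers: frames from the same person land on both sides of the split.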
evaluation/justify_thresholds.py
CHANGED

```diff
@@ -12,20 +12,19 @@ import matplotlib.pyplot as plt
 from sklearn.neural_network import MLPClassifier
 from sklearn.preprocessing import StandardScaler
 from sklearn.metrics import roc_curve, roc_auc_score, f1_score
-from xgboost import XGBClassifier
 
 _PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.insert(0, _PROJECT_ROOT)
 
-from data_preparation.prepare_dataset import load_per_person, SELECTED_FEATURES
+from data_preparation.prepare_dataset import get_default_split_config, load_per_person, SELECTED_FEATURES
+from models.xgboost.config import XGB_BASE_PARAMS, build_xgb_classifier
 
 PLOTS_DIR = os.path.join(os.path.dirname(__file__), "plots")
 REPORT_PATH = os.path.join(os.path.dirname(__file__), "THRESHOLD_JUSTIFICATION.md")
-SEED = …
+_, SEED = get_default_split_config()
 
-_USE_CLEARML = os.environ.get("USE_CLEARML", "0") == "1" or "--clearml" in sys.argv
+_USE_CLEARML = os.environ.get("USE_CLEARML", "0") == "1" or "--clearml" in sys.argv or bool(os.environ.get("CLEARML_TASK_ID"))
+_CLEARML_QUEUE = os.environ.get("CLEARML_QUEUE", "")
 
 _task = None
 _logger = None
@@ -33,13 +32,21 @@ _logger = None
 if _USE_CLEARML:
     try:
         from clearml import Task
+        from config import flatten_for_clearml
         _task = Task.init(
             project_name="Focus Guard",
             task_name="Threshold Justification",
             tags=["evaluation", "thresholds"],
         )
+        flat = flatten_for_clearml()
+        flat["evaluation/SEED"] = SEED
+        flat["evaluation/n_participants"] = 9
+        _task.connect(flat)
         _logger = _task.get_logger()
+        if _CLEARML_QUEUE:
+            print(f"[ClearML] Enqueuing to queue '{_CLEARML_QUEUE}'.")
+            _task.execute_remotely(queue_name=_CLEARML_QUEUE)
+            sys.exit(0)
         print("ClearML enabled — logging to project 'Focus Guard'")
     except ImportError:
         print("WARNING: ClearML not installed. Continuing without logging.")
@@ -107,13 +114,7 @@ def run_lopo_models():
         results["mlp"]["y"].append(y_test)
         results["mlp"]["p"].append(mlp_prob)
 
-        xgb = XGBClassifier(
-            n_estimators=600, max_depth=8, learning_rate=0.05,
-            subsample=0.8, colsample_bytree=0.8,
-            reg_alpha=0.1, reg_lambda=1.0,
-            use_label_encoder=False, eval_metric="logloss",
-            random_state=SEED, verbosity=0,
-        )
+        xgb = build_xgb_classifier(SEED, verbosity=0)
         xgb.fit(X_tr_sc, train_y)
         xgb_prob = xgb.predict_proba(X_te_sc)[:, 1]
         results["xgb"]["y"].append(y_test)
@@ -422,6 +423,8 @@ def write_report(model_stats, geo_f1, best_alpha, hybrid_f1, best_w, dist_stats)
 
     lines.append("## 1. ML Model Decision Thresholds")
     lines.append("")
+    lines.append(f"XGBoost config used for this report: `{XGB_BASE_PARAMS}`.")
+    lines.append("")
     lines.append("Thresholds selected via **Youden's J statistic** (J = sensitivity + specificity - 1) "
                  "on pooled LOPO held-out predictions.")
     lines.append("")
```
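The report generated here picks decision thresholds by Youden's J on pooled LOPO predictions. A minimal sketch of that selection rule — the toy labels and scores below are illustrative only, not the project's data:

```python
import numpy as np
from sklearn.metrics import roc_curve

def youden_threshold(y_true, y_prob):
    """Return the cutoff maximising J = TPR - FPR (sensitivity + specificity - 1)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    j = tpr - fpr
    return float(thresholds[int(np.argmax(j))])

# Toy example: scores cleanly separated around 0.5.
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
p = np.array([0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9])
t = youden_threshold(y, p)
```

At the selected cutoff every positive clears the threshold and no negative does, which is exactly the corner of the ROC curve that J rewards.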
evaluation/plots/roc_xgb.png
CHANGED
main.py
CHANGED
|
@@ -1,174 +1,68 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
|
| 4 |
-
from fastapi.staticfiles import StaticFiles
|
| 5 |
-
from fastapi.responses import FileResponse
|
| 6 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
-
from pydantic import BaseModel
|
| 8 |
-
from typing import Optional, List, Any
|
| 9 |
import base64
|
| 10 |
-
import
|
| 11 |
-
import numpy as np
|
| 12 |
-
import aiosqlite
|
| 13 |
import json
|
| 14 |
-
|
| 15 |
-
import math
|
| 16 |
import os
|
| 17 |
-
from pathlib import Path
|
| 18 |
-
from typing import Callable
|
| 19 |
-
from contextlib import asynccontextmanager
|
| 20 |
-
import asyncio
|
| 21 |
-
import concurrent.futures
|
| 22 |
import threading
|
| 23 |
-
import
|
|
|
|
|
|
|
|
|
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
|
| 26 |
-
|
| 27 |
-
logger = logging.getLogger(__name__)
|
| 28 |
from av import VideoFrame
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
from
|
| 31 |
-
from
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
)
|
|
|
|
| 35 |
from models.face_mesh import FaceMeshDetector
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
_FONT = cv2.FONT_HERSHEY_SIMPLEX
|
| 40 |
-
_CYAN = (255, 255, 0)
|
| 41 |
-
_GREEN = (0, 255, 0)
|
| 42 |
-
_MAGENTA = (255, 0, 255)
|
| 43 |
-
_ORANGE = (0, 165, 255)
|
| 44 |
_RED = (0, 0, 255)
|
| 45 |
-
|
| 46 |
-
_LIGHT_GREEN = (144, 238, 144)
|
| 47 |
-
|
| 48 |
-
_TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
|
| 49 |
-
_CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
|
| 50 |
-
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
|
| 51 |
-
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
|
| 52 |
-
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
|
| 53 |
-
_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
|
| 54 |
-
_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
|
| 55 |
-
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
|
| 56 |
-
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def _lm_px(lm, idx, w, h):
|
| 60 |
-
return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def _draw_polyline(frame, lm, indices, w, h, color, thickness):
|
| 64 |
-
for i in range(len(indices) - 1):
|
| 65 |
-
cv2.line(frame, _lm_px(lm, indices[i], w, h), _lm_px(lm, indices[i + 1], w, h), color, thickness, cv2.LINE_AA)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def _draw_face_mesh(frame, lm, w, h):
|
| 69 |
-
"""Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines."""
|
| 70 |
-
# Tessellation (gray triangular grid, semi-transparent)
|
| 71 |
-
overlay = frame.copy()
|
| 72 |
-
for s, e in _TESSELATION_CONNS:
|
| 73 |
-
cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
|
| 74 |
-
cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
|
| 75 |
-
# Contours
|
| 76 |
-
for s, e in _CONTOUR_CONNS:
|
| 77 |
-
cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
|
| 78 |
-
# Eyebrows
|
| 79 |
-
_draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
|
| 80 |
-
_draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
|
| 81 |
-
# Nose
|
| 82 |
-
_draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
|
| 83 |
-
# Lips
|
| 84 |
-
_draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
|
| 85 |
-
_draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
|
| 86 |
-
# Eyes
|
| 87 |
-
left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
|
| 88 |
-
cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
|
| 89 |
-
right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
|
| 90 |
-
cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
|
| 91 |
-
# EAR key points
|
| 92 |
-
for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
|
| 93 |
-
for idx in indices:
|
| 94 |
-
cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
|
| 95 |
-
# Irises + gaze lines
|
| 96 |
-
for iris_idx, eye_inner, eye_outer in [
|
| 97 |
-
(FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
|
| 98 |
-
(FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
|
| 99 |
-
]:
|
| 100 |
-
iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
|
| 101 |
-
center = iris_pts[0]
|
| 102 |
-
if len(iris_pts) >= 5:
|
| 103 |
-
radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
|
| 104 |
-
radius = max(int(np.mean(radii)), 2)
|
| 105 |
-
cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
|
| 106 |
-
cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
|
| 107 |
-
eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
|
| 108 |
-
eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
|
| 109 |
-
dx, dy = center[0] - eye_cx, center[1] - eye_cy
|
| 110 |
-
cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def _draw_hud(frame, result, model_name):
|
| 114 |
-
"""Draw status bar and detail overlay matching live_demo.py."""
|
| 115 |
-
h, w = frame.shape[:2]
|
| 116 |
-
is_focused = result["is_focused"]
|
| 117 |
-
status = "FOCUSED" if is_focused else "NOT FOCUSED"
|
| 118 |
-
color = _GREEN if is_focused else _RED
|
| 119 |
-
|
| 120 |
-
# Top bar
|
| 121 |
-
cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
|
| 122 |
-
cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
|
| 123 |
-
cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)
|
| 124 |
-
|
| 125 |
-
# Detail line
|
| 126 |
-
conf = result.get("mlp_prob", result.get("raw_score", 0.0))
|
| 127 |
-
mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
|
| 128 |
-
sf = result.get("s_face", 0)
|
| 129 |
-
se = result.get("s_eye", 0)
|
| 130 |
-
detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
|
| 131 |
-
cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)
|
| 132 |
-
|
| 133 |
-
# Head pose (top right)
|
| 134 |
-
if result.get("yaw") is not None:
|
| 135 |
-
cv2.putText(frame, f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
|
| 136 |
-
(w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
|
| 137 |
-
|
| 138 |
-
# Yawn indicator
|
| 139 |
-
if result.get("is_yawning"):
|
| 140 |
-
cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
|
| 141 |
-
|
| 142 |
-
# Landmark indices used for face mesh drawing on client (union of all groups).
|
| 143 |
-
# Sending only these instead of all 478 saves ~60% of the landmarks payload.
|
| 144 |
-
_MESH_INDICES = sorted(
|
| 145 |
-
set(
|
| 146 |
-
[
|
| 147 |
-
10, 338, 297, 332, 284, 251, 389, 356, 454,
|
| 148 |
-
323, 361, 288, 397, 365, 379, 378, 400, 377,
|
| 149 |
-
152, 148, 176, 149, 150, 136, 172, 58, 132,
|
| 150 |
-
93, 234, 127, 162, 21, 54, 103, 67, 109,
|
| 151 |
-
] # face oval
|
| 152 |
-
+ [33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246] # left eye
|
| 153 |
-
+ [362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398] # right eye
|
| 154 |
-
+ [468, 469, 470, 471, 472, 473, 474, 475, 476, 477] # irises
|
| 155 |
-
+ [70, 63, 105, 66, 107, 55, 65, 52, 53, 46] # left eyebrow
|
| 156 |
-
+ [300, 293, 334, 296, 336, 285, 295, 282, 283, 276] # right eyebrow
|
| 157 |
-
+ [6, 197, 195, 5, 4, 1, 19, 94, 2] # nose bridge
|
| 158 |
-
+ [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185] # lips outer
|
| 159 |
-
+ [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191] # lips inner
|
| 160 |
-
+ [33, 160, 158, 133, 153, 145] # left EAR key points
|
| 161 |
-
+ [362, 385, 387, 263, 373, 380] # right EAR key points
|
| 162 |
-
)
|
| 163 |
-
)
|
| 164 |
-
# Build a lookup: original_index -> position in sparse array, so client can reconstruct.
|
| 165 |
-
_MESH_INDEX_SET = set(_MESH_INDICES)
|
| 166 |
|
| 167 |
@asynccontextmanager
|
| 168 |
async def lifespan(app):
|
| 169 |
global _cached_model_name
|
| 170 |
print("Starting Focus Guard API")
|
| 171 |
-
await init_database()
|
| 172 |
async with aiosqlite.connect(db_path) as db:
|
| 173 |
cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
|
| 174 |
row = await cursor.fetchone()
|
|
@@ -226,9 +120,8 @@ app.add_middleware(
|
|
| 226 |
)
|
| 227 |
|
| 228 |
# Global variables
|
| 229 |
-
db_path = "focus_guard.db"
|
| 230 |
pcs = set()
|
| 231 |
-
_cached_model_name = "mlp"
|
| 232 |
_l2cs_boost_enabled = False
|
| 233 |
|
| 234 |
async def _wait_for_ice_gathering(pc: RTCPeerConnection):
|
|
@@ -243,54 +136,6 @@ async def _wait_for_ice_gathering(pc: RTCPeerConnection):
|
|
| 243 |
|
| 244 |
await done.wait()
|
| 245 |
|
| 246 |
-
# ================ DATABASE MODELS ================
|
| 247 |
-
|
| 248 |
-
async def init_database():
|
| 249 |
-
"""Initialize SQLite database with required tables"""
|
| 250 |
-
async with aiosqlite.connect(db_path) as db:
|
| 251 |
-
# FocusSessions table
|
| 252 |
-
await db.execute("""
|
| 253 |
-
CREATE TABLE IF NOT EXISTS focus_sessions (
|
| 254 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 255 |
-
start_time TIMESTAMP NOT NULL,
|
| 256 |
-
end_time TIMESTAMP,
|
| 257 |
-
duration_seconds INTEGER DEFAULT 0,
|
| 258 |
-
focus_score REAL DEFAULT 0.0,
|
| 259 |
-
total_frames INTEGER DEFAULT 0,
|
| 260 |
-
focused_frames INTEGER DEFAULT 0,
|
| 261 |
-
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 262 |
-
)
|
| 263 |
-
""")
|
| 264 |
-
|
| 265 |
-
# FocusEvents table
|
| 266 |
-
await db.execute("""
|
| 267 |
-
CREATE TABLE IF NOT EXISTS focus_events (
|
| 268 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 269 |
-
session_id INTEGER NOT NULL,
|
| 270 |
-
timestamp TIMESTAMP NOT NULL,
|
| 271 |
-
is_focused BOOLEAN NOT NULL,
|
| 272 |
-
confidence REAL NOT NULL,
|
| 273 |
-
detection_data TEXT,
|
| 274 |
-
FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
|
| 275 |
-
)
|
| 276 |
-
""")
|
| 277 |
-
|
| 278 |
-
# UserSettings table
|
| 279 |
-
await db.execute("""
|
| 280 |
-
CREATE TABLE IF NOT EXISTS user_settings (
|
| 281 |
-
id INTEGER PRIMARY KEY CHECK (id = 1),
|
| 282 |
-
model_name TEXT DEFAULT 'mlp'
|
| 283 |
-
)
|
| 284 |
-
""")
|
| 285 |
-
|
| 286 |
-
# Insert default settings if not exists
|
| 287 |
-
await db.execute("""
|
| 288 |
-
INSERT OR IGNORE INTO user_settings (id, model_name)
|
| 289 |
-
VALUES (1, 'mlp')
|
| 290 |
-
""")
|
| 291 |
-
|
| 292 |
-
await db.commit()
|
| 293 |
-
|
| 294 |
# ================ PYDANTIC MODELS ================
|
| 295 |
|
| 296 |
class SessionCreate(BaseModel):
|
|
@@ -319,8 +164,8 @@ class VideoTransformTrack(VideoStreamTrack):
|
|
| 319 |
if img is None:
|
| 320 |
return frame
|
| 321 |
|
| 322 |
-
|
| 323 |
-
img = cv2.resize(img, (
|
| 324 |
|
| 325 |
now = datetime.now().timestamp()
|
| 326 |
do_infer = (now - self.last_inference_time) >= self.min_inference_interval
|
|
@@ -357,8 +202,8 @@ class VideoTransformTrack(VideoStreamTrack):
|
|
| 357 |
h_f, w_f = img.shape[:2]
|
| 358 |
lm = out.get("landmarks")
|
| 359 |
if lm is not None:
|
| 360 |
-
|
| 361 |
-
|
| 362 |
else:
|
| 363 |
is_focused = False
|
| 364 |
confidence = 0.0
|
|
@@ -391,135 +236,6 @@ class VideoTransformTrack(VideoStreamTrack):
|
|
| 391 |
new_frame.time_base = frame.time_base
|
| 392 |
return new_frame
|
| 393 |
|
| 394 |
-
# ================ DATABASE OPERATIONS ================
|
| 395 |
-
|
| 396 |
-
async def create_session():
|
| 397 |
-
async with aiosqlite.connect(db_path) as db:
|
| 398 |
-
cursor = await db.execute(
|
| 399 |
-
"INSERT INTO focus_sessions (start_time) VALUES (?)",
|
| 400 |
-
(datetime.now().isoformat(),)
|
| 401 |
-
)
|
| 402 |
-
await db.commit()
|
| 403 |
-
return cursor.lastrowid
|
| 404 |
-
|
| 405 |
-
async def end_session(session_id: int):
|
| 406 |
-
async with aiosqlite.connect(db_path) as db:
|
| 407 |
-
cursor = await db.execute(
|
| 408 |
-
"SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
|
| 409 |
-
(session_id,)
|
| 410 |
-
)
|
| 411 |
-
row = await cursor.fetchone()
|
| 412 |
-
|
| 413 |
-
if not row:
|
| 414 |
-
return None
|
| 415 |
-
|
| 416 |
-
start_time_str, total_frames, focused_frames = row
|
| 417 |
-
start_time = datetime.fromisoformat(start_time_str)
|
| 418 |
-
end_time = datetime.now()
|
| 419 |
-
duration = (end_time - start_time).total_seconds()
|
| 420 |
-
focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
|
| 421 |
-
|
| 422 |
-
await db.execute("""
|
| 423 |
-
UPDATE focus_sessions
|
| 424 |
-
SET end_time = ?, duration_seconds = ?, focus_score = ?
|
| 425 |
-
WHERE id = ?
|
| 426 |
-
""", (end_time.isoformat(), int(duration), focus_score, session_id))
|
| 427 |
-
|
| 428 |
-
await db.commit()
|
| 429 |
-
|
| 430 |
-
return {
|
| 431 |
-
'session_id': session_id,
|
| 432 |
-
'start_time': start_time_str,
|
| 433 |
-
'end_time': end_time.isoformat(),
|
| 434 |
-
'duration_seconds': int(duration),
|
| 435 |
-
'focus_score': round(focus_score, 3),
|
| 436 |
-
'total_frames': total_frames,
|
| 437 |
-
'focused_frames': focused_frames
|
| 438 |
-
}
|
| 439 |
-
|
| 440 |
-
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
|
| 441 |
-
async with aiosqlite.connect(db_path) as db:
|
| 442 |
-
await db.execute("""
|
| 443 |
-
INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
|
| 444 |
-
VALUES (?, ?, ?, ?, ?)
|
| 445 |
-
""", (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
|
| 446 |
-
|
| 447 |
-
await db.execute("""
|
| 448 |
-
UPDATE focus_sessions
|
| 449 |
-
SET total_frames = total_frames + 1,
|
| 450 |
-
focused_frames = focused_frames + ?
|
| 451 |
-
WHERE id = ?
|
| 452 |
-
""", (1 if is_focused else 0, session_id))
|
| 453 |
-
await db.commit()
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
class _EventBuffer:
|
| 457 |
-
"""Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes."""
|
| 458 |
-
|
| 459 |
-
def __init__(self, flush_interval: float = 2.0):
|
| 460 |
-
self._buf: list = []
|
| 461 |
-
self._lock = asyncio.Lock()
|
| 462 |
-
self._flush_interval = flush_interval
|
| 463 |
-
self._task: asyncio.Task | None = None
|
| 464 |
-
self._total_frames = 0
|
| 465 |
-
self._focused_frames = 0
|
| 466 |
-
|
| 467 |
-
def start(self):
|
| 468 |
-
if self._task is None:
|
| 469 |
-
self._task = asyncio.create_task(self._flush_loop())
|
| 470 |
-
|
| 471 |
-
async def stop(self):
|
| 472 |
-
if self._task:
|
| 473 |
-
self._task.cancel()
|
| 474 |
-
try:
|
| 475 |
-
await self._task
|
| 476 |
-
except asyncio.CancelledError:
|
| 477 |
-
pass
|
| 478 |
-
self._task = None
|
| 479 |
-
await self._flush()
|
| 480 |
-
|
| 481 |
-
def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict):
|
| 482 |
-
self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
|
| 483 |
-
self._total_frames += 1
|
| 484 |
-
if is_focused:
|
| 485 |
-
self._focused_frames += 1
|
| 486 |
-
|
| 487 |
-
async def _flush_loop(self):
|
| 488 |
-
while True:
|
| 489 |
-
await asyncio.sleep(self._flush_interval)
|
| 490 |
-
await self._flush()
|
| 491 |
-
|
| 492 |
-
async def _flush(self):
|
| 493 |
-
async with self._lock:
|
| 494 |
-
if not self._buf:
|
| 495 |
-
return
|
| 496 |
-
batch = self._buf[:]
|
| 497 |
-
total = self._total_frames
|
| 498 |
-
focused = self._focused_frames
|
| 499 |
-
self._buf.clear()
|
| 500 |
-
self._total_frames = 0
|
| 501 |
-
self._focused_frames = 0
|
| 502 |
-
|
| 503 |
-
if not batch:
|
| 504 |
-
return
|
| 505 |
-
|
| 506 |
-
session_id = batch[0][0]
|
| 507 |
-
try:
|
| 508 |
-
async with aiosqlite.connect(db_path) as db:
|
| 509 |
-
await db.executemany("""
|
| 510 |
-
INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
|
| 511 |
-
VALUES (?, ?, ?, ?, ?)
|
| 512 |
-
""", batch)
|
| 513 |
-
await db.execute("""
|
| 514 |
-
UPDATE focus_sessions
|
| 515 |
-
SET total_frames = total_frames + ?,
|
| 516 |
-
focused_frames = focused_frames + ?
|
| 517 |
-
WHERE id = ?
|
| 518 |
-
""", (total, focused, session_id))
|
| 519 |
-
await db.commit()
|
| 520 |
-
except Exception as e:
|
| 521 |
-
print(f"[DB] Flush error: {e}")
|
| 522 |
-
|
| 523 |
# ================ STARTUP/SHUTDOWN ================
|
| 524 |
|
| 525 |
pipelines = {
|
|
@@ -532,7 +248,7 @@ pipelines = {
|
|
| 532 |
|
| 533 |
# Thread pool for CPU-bound inference so the event loop stays responsive.
|
| 534 |
_inference_executor = concurrent.futures.ThreadPoolExecutor(
|
| 535 |
-
max_workers=
|
| 536 |
thread_name_prefix="inference",
|
| 537 |
)
|
| 538 |
# One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
|
|
@@ -607,7 +323,7 @@ def _process_frame_with_l2cs_boost(base_pipeline, frame, base_model_name):
|
|
| 607 |
is_focused = False
|
| 608 |
else:
|
| 609 |
fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score
|
| 610 |
-
is_focused = fused_score >=
|
| 611 |
|
| 612 |
base_out["raw_score"] = fused_score
|
| 613 |
base_out["is_focused"] = is_focused
|
|
@@ -680,7 +396,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 680 |
session_id = None
|
| 681 |
frame_count = 0
|
| 682 |
running = True
|
| 683 |
-
event_buffer =
|
| 684 |
|
| 685 |
# Calibration state (per-connection)
|
| 686 |
# verifying: after fit, show a verification target and check gaze accuracy
|
|
@@ -855,7 +571,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 855 |
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
| 856 |
if frame is None:
|
| 857 |
continue
|
| 858 |
-
frame = cv2.resize(frame, (
|
| 859 |
|
| 860 |
# During calibration collection, always use L2CS
|
| 861 |
collecting = _cal.get("collecting", False)
|
|
@@ -937,7 +653,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 937 |
elif use_boost and not fuse["on_screen"]:
|
| 938 |
# Boost mode: if gaze is clearly off-screen, override to unfocused
|
| 939 |
is_focused = False
|
| 940 |
-
confidence = min(confidence,
|
| 941 |
|
| 942 |
if session_id:
|
| 943 |
metadata = {
|
|
@@ -980,7 +696,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 980 |
resp["confidence"] = round(fuse["focus_score"], 3)
|
| 981 |
elif use_boost and not fuse["on_screen"]:
|
| 982 |
resp["focused"] = False
|
| 983 |
-
resp["confidence"] = min(resp["confidence"],
|
| 984 |
if has_gaze:
|
| 985 |
resp["gaze_yaw"] = round(out["gaze_yaw"], 4)
|
| 986 |
resp["gaze_pitch"] = round(out["gaze_pitch"], 4)
|
|
@@ -1133,7 +849,7 @@ async def update_settings(settings: SettingsUpdate):
|
|
| 1133 |
cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
|
| 1134 |
exists = await cursor.fetchone()
|
| 1135 |
if not exists:
|
| 1136 |
-
await db.execute("INSERT INTO user_settings (id,
|
| 1137 |
await db.commit()
|
| 1138 |
|
| 1139 |
updates = []
|
|
@@ -1278,7 +994,7 @@ async def l2cs_status():
|
|
| 1278 |
@app.get("/api/mesh-topology")
|
| 1279 |
async def get_mesh_topology():
|
| 1280 |
"""Return tessellation edge pairs for client-side face mesh drawing (cached by client)."""
|
| 1281 |
-
return {"tessellation":
|
| 1282 |
|
| 1283 |
@app.get("/health")
|
| 1284 |
async def health_check():
|
|
| 1 | from __future__ import annotations
| 2 |
| 3 | + import asyncio
| 4 | import base64
| 5 | + import concurrent.futures
| 6 | import json
| 7 | + import logging
| 8 | import os
| 9 | import threading
| 10 | + from contextlib import asynccontextmanager
| 11 | + from datetime import datetime, timedelta
| 12 | + from pathlib import Path
| 13 | + from typing import Any, Callable, List, Optional
| 14 |
| 15 | + import aiosqlite
| 16 | + import cv2
| 17 | + import numpy as np
| 18 | from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
| 19 | from av import VideoFrame
| 20 | + from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
| 21 | + from fastapi.middleware.cors import CORSMiddleware
| 22 | + from fastapi.responses import FileResponse
| 23 | + from fastapi.staticfiles import StaticFiles
| 24 | + from pydantic import BaseModel
| 25 |
| 26 | + from api.drawing import draw_face_mesh, draw_hud, get_tesselation_connections
| 27 | + from api.db import (
| 28 | + EventBuffer,
| 29 | + create_session,
| 30 | + end_session,
| 31 | + get_db_path,
| 32 | + init_database,
| 33 | + store_focus_event,
| 34 | )
| 35 | + from config import get
| 36 | from models.face_mesh import FaceMeshDetector
| 37 | + from ui.pipeline import (
| 38 | + FaceMeshPipeline,
| 39 | + HybridFocusPipeline,
| 40 | + L2CSPipeline,
| 41 | + MLPPipeline,
| 42 | + XGBoostPipeline,
| 43 | + is_l2cs_weights_available,
| 44 | + )
| 45 | +
| 46 | + logger = logging.getLogger(__name__)
| 47 |
| 48 | + db_path = get_db_path()
| 49 | + _inference_size = get("app.inference_size") or [640, 480]
| 50 | + _inference_workers = get("app.inference_workers") or 4
| 51 | + _fused_threshold = get("l2cs_boost.fused_threshold") or 0.52
| 52 | + _no_face_cap = get("app.no_face_confidence_cap") or 0.1
| 53 | + _BOOST_BASE_W = get("l2cs_boost.base_weight") or 0.35
| 54 | + _BOOST_L2CS_W = get("l2cs_boost.l2cs_weight") or 0.65
| 55 | + _BOOST_VETO = get("l2cs_boost.veto_threshold") or 0.38
| 56 |
| 57 | _FONT = cv2.FONT_HERSHEY_SIMPLEX
| 58 | _RED = (0, 0, 255)
| 59 | +
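The `get("app.inference_size") or [640, 480]` reads above assume a dotted-key lookup into the YAML config with a hard-coded fallback. A minimal sketch of that pattern; `get_with_default` is a hypothetical helper, not the project's `config.get`:

```python
def get_with_default(cfg: dict, dotted_key: str, default):
    """Walk a nested dict by dotted key; return default when any segment is missing."""
    node = cfg
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node

cfg = {"app": {"inference_size": [640, 480]}, "l2cs_boost": {"fused_threshold": 0.52}}
assert get_with_default(cfg, "l2cs_boost.fused_threshold", 0.5) == 0.52
assert get_with_default(cfg, "app.inference_workers", 4) == 4  # missing key -> fallback
```

Note that the `get(...) or default` idiom in main.py additionally falls back when the configured value is falsy (e.g. `0`), which the defaults here never are.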
| 60 |
| 61 | @asynccontextmanager
| 62 | async def lifespan(app):
| 63 | global _cached_model_name
| 64 | print("Starting Focus Guard API")
| 65 | + await init_database(db_path)
| 66 | async with aiosqlite.connect(db_path) as db:
| 67 | cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
| 68 | row = await cursor.fetchone()

| 120 | )
| 121 |
| 122 | # Global variables
| 123 | pcs = set()
| 124 | + _cached_model_name = get("app.default_model") or "mlp"
| 125 | _l2cs_boost_enabled = False
| 126 |
| 127 | async def _wait_for_ice_gathering(pc: RTCPeerConnection):

| 136 |
| 137 | await done.wait()
| 138 |
| 139 | # ================ PYDANTIC MODELS ================
| 140 |
| 141 | class SessionCreate(BaseModel):

| 164 | if img is None:
| 165 | return frame
| 166 |
| 167 | + w_sz, h_sz = _inference_size[0], _inference_size[1]
| 168 | + img = cv2.resize(img, (w_sz, h_sz))
| 169 |
| 170 | now = datetime.now().timestamp()
| 171 | do_infer = (now - self.last_inference_time) >= self.min_inference_interval

| 202 | h_f, w_f = img.shape[:2]
| 203 | lm = out.get("landmarks")
| 204 | if lm is not None:
| 205 | + draw_face_mesh(img, lm, w_f, h_f)
| 206 | + draw_hud(img, out, model_name)
| 207 | else:
| 208 | is_focused = False
| 209 | confidence = 0.0

| 236 | new_frame.time_base = frame.time_base
| 237 | return new_frame
| 238 |
| 239 | # ================ STARTUP/SHUTDOWN ================
| 240 |
| 241 | pipelines = {

| 248 |
| 249 | # Thread pool for CPU-bound inference so the event loop stays responsive.
| 250 | _inference_executor = concurrent.futures.ThreadPoolExecutor(
| 251 | + max_workers=_inference_workers,
| 252 | thread_name_prefix="inference",
| 253 | )
| 254 | # One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when

| 323 | is_focused = False
| 324 | else:
| 325 | fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score
| 326 | + is_focused = fused_score >= _fused_threshold
| 327 |
| 328 | base_out["raw_score"] = fused_score
| 329 | base_out["is_focused"] = is_focused
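The fusion hunk around lines 323–329 combines the base classifier score with the L2CS gaze score using the `l2cs_boost` weights from the config. A toy sketch, assuming `_BOOST_VETO` forces an unfocused verdict when the gaze score falls below it (the veto branch itself is not fully visible in this diff):

```python
BASE_W, L2CS_W = 0.35, 0.65   # l2cs_boost.base_weight / l2cs_boost.l2cs_weight
FUSED_THRESHOLD = 0.52        # l2cs_boost.fused_threshold
VETO = 0.38                   # l2cs_boost.veto_threshold (assumed semantics)

def fuse(base_score: float, l2cs_score: float) -> tuple[float, bool]:
    """Weighted average of the two scores; gaze dominates, and a very low gaze score vetoes focus."""
    fused = BASE_W * base_score + L2CS_W * l2cs_score
    if l2cs_score < VETO:
        return fused, False   # gaze clearly off-screen: force unfocused
    return fused, fused >= FUSED_THRESHOLD

fused, focused = fuse(0.8, 0.6)   # 0.35*0.8 + 0.65*0.6 = 0.67 >= 0.52
assert focused
assert fuse(0.9, 0.3)[1] is False  # vetoed despite a strong base score
```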
|
|
|
| 396 | session_id = None
| 397 | frame_count = 0
| 398 | running = True
| 399 | + event_buffer = EventBuffer(db_path=db_path, flush_interval=2.0)
| 400 |
| 401 | # Calibration state (per-connection)
| 402 | # verifying: after fit, show a verification target and check gaze accuracy

| 571 | frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
| 572 | if frame is None:
| 573 | continue
| 574 | + frame = cv2.resize(frame, (_inference_size[0], _inference_size[1]))
| 575 |
| 576 | # During calibration collection, always use L2CS
| 577 | collecting = _cal.get("collecting", False)

| 653 | elif use_boost and not fuse["on_screen"]:
| 654 | # Boost mode: if gaze is clearly off-screen, override to unfocused
| 655 | is_focused = False
| 656 | + confidence = min(confidence, _no_face_cap)
| 657 |
| 658 | if session_id:
| 659 | metadata = {

| 696 | resp["confidence"] = round(fuse["focus_score"], 3)
| 697 | elif use_boost and not fuse["on_screen"]:
| 698 | resp["focused"] = False
| 699 | + resp["confidence"] = min(resp["confidence"], _no_face_cap)
| 700 | if has_gaze:
| 701 | resp["gaze_yaw"] = round(out["gaze_yaw"], 4)
| 702 | resp["gaze_pitch"] = round(out["gaze_pitch"], 4)

| 849 | cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
| 850 | exists = await cursor.fetchone()
| 851 | if not exists:
| 852 | + await db.execute("INSERT INTO user_settings (id, model_name) VALUES (1, 'mlp')")
| 853 | await db.commit()
| 854 |
| 855 | updates = []

| 994 | @app.get("/api/mesh-topology")
| 995 | async def get_mesh_topology():
| 996 | """Return tessellation edge pairs for client-side face mesh drawing (cached by client)."""
| 997 | + return {"tessellation": get_tesselation_connections()}
| 998 |
| 999 | @app.get("/health")
| 1000 | async def health_check():
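The websocket handler above now constructs `EventBuffer(db_path=db_path, flush_interval=2.0)` instead of writing each focus event individually. A toy, synchronous stand-in for that batching idea (the real `api.db.EventBuffer` is async and writes to SQLite):

```python
class ToyEventBuffer:
    """Events accumulate in memory and are flushed as one batch at most
    every flush_interval seconds (time is passed in explicitly here)."""

    def __init__(self, flush_interval: float = 2.0):
        self.flush_interval = flush_interval
        self.pending = []
        self.flushed = []      # stands in for rows written to the database
        self.last_flush = 0.0

    def add(self, event, now: float):
        self.pending.append(event)
        if now - self.last_flush >= self.flush_interval:
            self.flushed.extend(self.pending)  # one batched insert instead of many
            self.pending.clear()
            self.last_flush = now

buf = ToyEventBuffer(flush_interval=2.0)
buf.add("focus", 0.5)
buf.add("unfocus", 1.0)
assert buf.flushed == []                       # still buffered
buf.add("focus", 2.5)                          # interval elapsed -> flush all three
assert buf.flushed == ["focus", "unfocus", "focus"]
```

Batching keeps the per-frame path cheap: the websocket loop only appends to a list, and the expensive database write happens at most once per interval.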
models/L2CS-Net/l2cs/datasets.py
CHANGED
|
@@ -59,11 +59,6 @@ class Gaze360(Dataset):
| 59 |
| 60 | img = Image.open(os.path.join(self.root, face))
| 61 |
| 62 | - # fimg = cv2.imread(os.path.join(self.root, face))
| 63 | - # fimg = cv2.resize(fimg, (448, 448))/255.0
| 64 | - # fimg = fimg.transpose(2, 0, 1)
| 65 | - # img=torch.from_numpy(fimg).type(torch.FloatTensor)
| 66 | -
| 67 | if self.transform:
| 68 | img = self.transform(img)
| 69 |

@@ -135,11 +130,6 @@ class Mpiigaze(Dataset):
| 135 |
| 136 | img = Image.open(os.path.join(self.root, face))
| 137 |
| 138 | - # fimg = cv2.imread(os.path.join(self.root, face))
| 139 | - # fimg = cv2.resize(fimg, (448, 448))/255.0
| 140 | - # fimg = fimg.transpose(2, 0, 1)
| 141 | - # img=torch.from_numpy(fimg).type(torch.FloatTensor)
| 142 | -
| 143 | if self.transform:
| 144 | img = self.transform(img)
| 145 |

| 59 |
| 60 | img = Image.open(os.path.join(self.root, face))
| 61 |
| 62 | if self.transform:
| 63 | img = self.transform(img)
| 64 |

| 130 |
| 131 | img = Image.open(os.path.join(self.root, face))
| 132 |
| 133 | if self.transform:
| 134 | img = self.transform(img)
| 135 |
models/mlp/eval_accuracy.py
CHANGED
|
@@ -25,8 +25,6 @@ def main():
| 25 | train_loader, val_loader, test_loader, num_features, num_classes, _ = get_dataloaders(
| 26 | model_name="face_orientation",
| 27 | batch_size=32,
| 28 | - split_ratios=(0.7, 0.15, 0.15),
| 29 | - seed=42,
| 30 | )
| 31 |
| 32 | model = BaseModel(num_features, num_classes).to(device)

| 25 | train_loader, val_loader, test_loader, num_features, num_classes, _ = get_dataloaders(
| 26 | model_name="face_orientation",
| 27 | batch_size=32,
| 28 | )
| 29 |
| 30 | model = BaseModel(num_features, num_classes).to(device)
models/mlp/sweep.py
CHANGED
|
@@ -14,10 +14,10 @@ REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
| 14 | if REPO_ROOT not in sys.path:
| 15 | sys.path.insert(0, REPO_ROOT)
| 16 |
| 17 | - from data_preparation.prepare_dataset import get_dataloaders
| 18 | from models.mlp.train import BaseModel, set_seed
| 19 |
| 20 | - SEED =
| 21 | N_TRIALS = 20
| 22 | EPOCHS_PER_TRIAL = 15
| 23 |

@@ -31,7 +31,7 @@ def objective(trial):
| 31 | train_loader, val_loader, _, num_features, num_classes, _ = get_dataloaders(
| 32 | model_name="face_orientation",
| 33 | batch_size=batch_size,
| 34 | - split_ratios=
| 35 | seed=SEED,
| 36 | )
| 37 |

| 14 | if REPO_ROOT not in sys.path:
| 15 | sys.path.insert(0, REPO_ROOT)
| 16 |
| 17 | + from data_preparation.prepare_dataset import get_default_split_config, get_dataloaders
| 18 | from models.mlp.train import BaseModel, set_seed
| 19 |
| 20 | + SPLIT_RATIOS, SEED = get_default_split_config()
| 21 | N_TRIALS = 20
| 22 | EPOCHS_PER_TRIAL = 15
| 23 |

| 31 | train_loader, val_loader, _, num_features, num_classes, _ = get_dataloaders(
| 32 | model_name="face_orientation",
| 33 | batch_size=batch_size,
| 34 | + split_ratios=SPLIT_RATIOS,
| 35 | seed=SEED,
| 36 | )
| 37 |
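The sweep now pulls `SPLIT_RATIOS, SEED` from `get_default_split_config()` instead of hard-coding them in each script. A hypothetical stand-in showing the intended contract (the real function lives in `data_preparation/prepare_dataset.py`; the values match the `(0.7, 0.15, 0.15)` / `seed=42` arguments removed elsewhere in this commit):

```python
def get_default_split_config():
    """Single source of truth for train/val/test ratios and the RNG seed."""
    return (0.7, 0.15, 0.15), 42

SPLIT_RATIOS, SEED = get_default_split_config()
assert abs(sum(SPLIT_RATIOS) - 1.0) < 1e-9  # ratios must cover the whole dataset
assert SEED == 42
```

Centralizing the split config means every trainer, sweep, and evaluation script sees identical splits, so reported test metrics stay comparable across models.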
models/mlp/train.py
CHANGED
|
@@ -1,46 +1,95 @@
| 1 | import json
| 2 | import os
| 3 | import random
| 4 |
| 5 | import numpy as np
| 6 | import joblib
| 7 | import torch
| 8 | import torch.nn as nn
| 9 | import torch.optim as optim
| 10 | - from sklearn.metrics import
| 11 |
| 12 | from data_preparation.prepare_dataset import get_dataloaders, SELECTED_FEATURES
| 13 |
| 14 | - USE_CLEARML = False
| 15 | -
| 16 | _PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
| 17 | -
| 18 | -
| 19 | -
| 20 | -
| 21 | -
| 22 | -
| 23 | - "
| 24 | -
| 25 | -
| 26 | - }
| 27 | -
| 28 | -
| 29 | -
| 30 | task = None
| 31 | if USE_CLEARML:
| 32 | -
| 33 | -
| 34 | -
| 35 | -
| 36 | -
| 37 | -
| 38 | -
| 39 |
| 40 |
| 41 |
| 42 | # ==== Model =============================================
| 43 | - def set_seed(seed: int):
| 44 | random.seed(seed)
| 45 | np.random.seed(seed)
| 46 | torch.manual_seed(seed)

@@ -49,15 +98,18 @@ def set_seed(seed: int):
| 49 |
| 50 |
| 51 | class BaseModel(nn.Module):
| 52 | -
| 53 | super().__init__()
| 54 | -
| 55 | -
| 56 | -
| 57 | -
| 58 | - nn.
| 59 | -
| 60 | - )
| 61 |
| 62 | def forward(self, x):
| 63 | return self.network(x)

@@ -89,6 +141,8 @@ class BaseModel(nn.Module):
| 89 | total_loss = 0.0
| 90 | correct = 0
| 91 | total = 0
| 92 |
| 93 | for features, labels in loader:
| 94 | features, labels = features.to(device), labels.to(device)

@@ -96,10 +150,14 @@ class BaseModel(nn.Module):
| 96 | loss = criterion(outputs, labels)
| 97 |
| 98 | total_loss += loss.item() * features.size(0)
| 99 | -
| 100 | total += features.size(0)
| 101 |
| 102 | -
| 103 |
| 104 | @torch.no_grad()
| 105 | def test_step(self, loader, criterion, device):

@@ -130,7 +188,8 @@ class BaseModel(nn.Module):
| 130 | return total_loss / total, correct / total, np.array(all_probs), np.array(all_preds), np.array(all_labels)
| 131 |
| 132 |
| 133 | - def main():
| 134 | set_seed(CFG["seed"])
| 135 |
| 136 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -144,7 +203,7 @@ def main():
| 144 | seed=CFG["seed"],
| 145 | )
| 146 |
| 147 | - model = BaseModel(num_features, num_classes).to(device)
| 148 | criterion = nn.CrossEntropyLoss()
| 149 | optimizer = optim.Adam(model.parameters(), lr=CFG["lr"])
| 150 |

@@ -163,22 +222,25 @@ def main():
| 163 | "train_acc": [],
| 164 | "val_loss": [],
| 165 | "val_acc": [],
| 166 | }
| 167 |
| 168 | best_val_acc = 0.0
| 169 |
| 170 | - print(f"\n{'Epoch':>6} | {'Train Loss':>10} | {'Train Acc':>9} | {'Val Loss':>10} | {'Val Acc':>9}")
| 171 | - print("-" *
| 172 |
| 173 | for epoch in range(1, CFG["epochs"] + 1):
| 174 | train_loss, train_acc = model.training_step(train_loader, optimizer, criterion, device)
| 175 | - val_loss, val_acc = model.validation_step(val_loader, criterion, device)
| 176 |
| 177 | history["epochs"].append(epoch)
| 178 | history["train_loss"].append(round(train_loss, 4))
| 179 | history["train_acc"].append(round(train_acc, 4))
| 180 | history["val_loss"].append(round(val_loss, 4))
| 181 | history["val_acc"].append(round(val_acc, 4))
| 182 |
| 183 |
| 184 | current_lr = optimizer.param_groups[0]['lr']

@@ -187,30 +249,36 @@ def main():
| 187 | task.logger.report_scalar("Accuracy", "Train", float(train_acc), iteration=epoch)
| 188 | task.logger.report_scalar("Loss", "Val", float(val_loss), iteration=epoch)
| 189 | task.logger.report_scalar("Accuracy", "Val", float(val_acc), iteration=epoch)
| 190 | task.logger.report_scalar("Learning Rate", "LR", float(current_lr), iteration=epoch)
| 191 | task.logger.flush()
| 192 |
| 193 | marker = ""
| 194 | - if
| 195 | best_val_acc = val_acc
| 196 | torch.save(model.state_dict(), best_ckpt_path)
| 197 | marker = " *"
| 198 |
| 199 | - print(
| 200 |
| 201 | - print(f"\nBest validation accuracy: {best_val_acc:.2%}")
| 202 | print(f"Checkpoint saved to: {best_ckpt_path}")
| 203 |
| 204 | model.load_state_dict(torch.load(best_ckpt_path, weights_only=True))
| 205 | test_loss, test_acc, test_probs, test_preds, test_labels = model.test_step(test_loader, criterion, device)
| 206 | -
| 207 | -
| 208 | -
| 209 | if num_classes > 2:
| 210 | - test_auc = roc_auc_score(
| 211 | else:
| 212 | - test_auc = roc_auc_score(
| 213 | -
| 214 | print(f"\n[TEST] Loss: {test_loss:.4f} | Accuracy: {test_acc:.2%}")
| 215 | print(f"[TEST] F1: {test_f1:.4f} | ROC-AUC: {test_auc:.4f}")

@@ -219,22 +287,72 @@ def main():
| 219 | history["test_f1"] = round(test_f1, 4)
| 220 | history["test_auc"] = round(test_auc, 4)
| 221 |
| 222 | logs_dir = CFG["logs_dir"]
| 223 | os.makedirs(logs_dir, exist_ok=True)
| 224 | log_path = os.path.join(logs_dir, f"{CFG['model_name']}_training_log.json")
| 225 | -
| 226 | with open(log_path, "w") as f:
| 227 | json.dump(history, f, indent=2)
| 228 | -
| 229 | print(f"[LOG] Training history saved to: {log_path}")
| 230 |
| 231 | - # Save scaler and feature names for inference (ui/pipeline.py)
| 232 | scaler_path = os.path.join(ckpt_dir, "scaler_mlp.joblib")
| 233 | joblib.dump(scaler, scaler_path)
| 234 | meta_path = os.path.join(ckpt_dir, "meta_mlp.npz")
| 235 | np.savez(meta_path, feature_names=np.array(SELECTED_FEATURES["face_orientation"]))
| 236 | print(f"[LOG] Scaler and meta saved to {ckpt_dir}")
| 237 |
| 238 |
| 239 | if __name__ == "__main__":
| 240 | main()
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
import random
|
| 4 |
+
import sys
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import joblib
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.optim as optim
|
| 11 |
+
from sklearn.metrics import (
|
| 12 |
+
confusion_matrix,
|
| 13 |
+
f1_score,
|
| 14 |
+
precision_recall_fscore_support,
|
| 15 |
+
roc_auc_score,
|
| 16 |
+
)
|
| 17 |
|
| 18 |
from data_preparation.prepare_dataset import get_dataloaders, SELECTED_FEATURES
|
| 19 |
|
|
|
|
|
|
|
| 20 |
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
| 21 |
+
|
| 22 |
+
USE_CLEARML = os.environ.get("USE_CLEARML", "0") == "1" or bool(os.environ.get("CLEARML_TASK_ID"))
|
| 23 |
+
CLEARML_QUEUE = os.environ.get("CLEARML_QUEUE", "")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _load_cfg():
|
| 27 |
+
"""Build training config from config/default.yaml with fallbacks."""
|
| 28 |
+
try:
|
| 29 |
+
from config import get
|
| 30 |
+
mlp = get("mlp") or {}
|
| 31 |
+
data = get("data") or {}
|
| 32 |
+
ratios = data.get("split_ratios", [0.7, 0.15, 0.15])
|
| 33 |
+
return {
|
| 34 |
+
"model_name": mlp.get("model_name", "face_orientation"),
|
| 35 |
+
"epochs": mlp.get("epochs", 30),
|
| 36 |
+
"batch_size": mlp.get("batch_size", 32),
|
| 37 |
+
"lr": mlp.get("lr", 1e-3),
|
| 38 |
+
"seed": mlp.get("seed", 42),
|
| 39 |
+
"split_ratios": tuple(ratios),
|
| 40 |
+
"hidden_sizes": mlp.get("hidden_sizes", [64, 32]),
|
| 41 |
+
"checkpoints_dir": os.path.join(_PROJECT_ROOT, "checkpoints"),
|
| 42 |
+
"logs_dir": os.path.join(_PROJECT_ROOT, "evaluation", "logs"),
|
| 43 |
+
}
|
| 44 |
+
except Exception:
|
| 45 |
+
return {
|
| 46 |
+
"model_name": "face_orientation",
|
| 47 |
+
"epochs": 30,
|
| 48 |
+
"batch_size": 32,
|
| 49 |
+
"lr": 1e-3,
|
| 50 |
+
"seed": 42,
|
| 51 |
+
"split_ratios": (0.7, 0.15, 0.15),
|
| 52 |
+
"hidden_sizes": [64, 32],
|
| 53 |
+
"checkpoints_dir": os.path.join(_PROJECT_ROOT, "checkpoints"),
|
| 54 |
+
"logs_dir": os.path.join(_PROJECT_ROOT, "evaluation", "logs"),
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
CFG = _load_cfg()
|
| 59 |
+
|
| 60 |
+
# ==== ClearML: expose all config as task params, support remote execution ====
|
| 61 |
task = None
|
| 62 |
if USE_CLEARML:
|
| 63 |
+
try:
|
| 64 |
+
from clearml import Task
|
| 65 |
+
from config import flatten_for_clearml
|
| 66 |
+
task = Task.init(
|
| 67 |
+
project_name="Focus Guard",
|
| 68 |
+
task_name="MLP Model Training",
|
| 69 |
+
tags=["training", "mlp_model"],
|
| 70 |
+
)
|
| 71 |
+
flat = flatten_for_clearml()
|
| 72 |
+
flat["mlp/model_name"] = CFG.get("model_name", "face_orientation")
|
| 73 |
+
flat["mlp/epochs"] = CFG.get("epochs", 30)
|
| 74 |
+
flat["mlp/batch_size"] = CFG.get("batch_size", 32)
|
| 75 |
+
flat["mlp/lr"] = CFG.get("lr", 1e-3)
|
| 76 |
+
flat["mlp/seed"] = CFG.get("seed", 42)
|
| 77 |
+
flat["mlp/hidden_sizes"] = str(CFG.get("hidden_sizes", [64, 32]))
|
| 78 |
+
flat["mlp/split_ratios"] = str(CFG.get("split_ratios", (0.7, 0.15, 0.15)))
|
| 79 |
+
task.connect(flat)
|
| 80 |
+
if CLEARML_QUEUE:
|
| 81 |
+
print(f"[ClearML] Enqueuing to queue '{CLEARML_QUEUE}'. Agent will run training.")
|
| 82 |
+
task.execute_remotely(queue_name=CLEARML_QUEUE)
|
| 83 |
+
sys.exit(0)
|
| 84 |
+
except ImportError:
|
| 85 |
+
task = None
|
| 86 |
+
USE_CLEARML = False
|
| 87 |
|
| 88 |
|
| 89 |
|
| 90 |
# ==== Model =============================================
|
| 91 |
+
def set_seed(seed: int) -> None:
|
| 92 |
+
"""Set random seed for numpy, torch, and Python RNG for reproducibility."""
|
| 93 |
random.seed(seed)
|
| 94 |
np.random.seed(seed)
|
| 95 |
torch.manual_seed(seed)
|
|
|
|
| 98 |
|
| 99 |
|
| 100 |
class BaseModel(nn.Module):
|
| 101 |
+
"""MLP classifier: num_features -> hidden_sizes -> num_classes. Used for face_orientation focus."""
|
| 102 |
+
|
| 103 |
+
def __init__(self, num_features: int, num_classes: int, hidden_sizes: list[int] | None = None):
|
| 104 |
super().__init__()
|
| 105 |
+
sizes = hidden_sizes or CFG.get("hidden_sizes", [64, 32])
|
| 106 |
+
layers = []
|
| 107 |
+
prev = num_features
|
| 108 |
+
for h in sizes:
|
| 109 |
+
layers.extend([nn.Linear(prev, h), nn.ReLU()])
|
| 110 |
+
prev = h
|
| 111 |
+
layers.append(nn.Linear(prev, num_classes))
|
| 112 |
+
self.network = nn.Sequential(*layers)
|
| 113 |
|
| 114 |
def forward(self, x):
|
| 115 |
return self.network(x)
|
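The rewritten `BaseModel.__init__` above builds one `nn.Linear` plus `nn.ReLU` per entry of `hidden_sizes`, then a final projection to `num_classes`. A torch-free sketch of the layer shapes that loop produces:

```python
def mlp_layer_shapes(num_features, hidden_sizes, num_classes):
    """Consecutive (in, out) pairs produced by the hidden-size loop in BaseModel.__init__."""
    dims = [num_features, *hidden_sizes, num_classes]
    return list(zip(dims[:-1], dims[1:]))

# 17 MediaPipe-derived features, two focus classes, default hidden_sizes [64, 32]
assert mlp_layer_shapes(17, [64, 32], 2) == [(17, 64), (64, 32), (32, 2)]
```

Because the architecture is driven by the `hidden_sizes` list from the config, the Optuna sweep can vary depth and width without touching the model class.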
|
|
|
| 141 |
total_loss = 0.0
|
| 142 |
correct = 0
|
| 143 |
total = 0
|
| 144 |
+
all_preds = []
|
| 145 |
+
all_labels = []
|
| 146 |
|
| 147 |
for features, labels in loader:
|
| 148 |
features, labels = features.to(device), labels.to(device)
|
|
|
|
| 150 |
loss = criterion(outputs, labels)
|
| 151 |
|
| 152 |
total_loss += loss.item() * features.size(0)
|
| 153 |
+
preds = outputs.argmax(dim=1)
|
| 154 |
+
correct += (preds == labels).sum().item()
|
| 155 |
total += features.size(0)
|
| 156 |
+
all_preds.extend(preds.cpu().numpy())
|
| 157 |
+
all_labels.extend(labels.cpu().numpy())
|
| 158 |
|
| 159 |
+
val_f1 = f1_score(np.array(all_labels), np.array(all_preds), average="weighted")
|
| 160 |
+
return total_loss / total, correct / total, val_f1
|
| 161 |
|
| 162 |
@torch.no_grad()
|
| 163 |
def test_step(self, loader, criterion, device):
|
|
|
|
| 188 |
return total_loss / total, correct / total, np.array(all_probs), np.array(all_preds), np.array(all_labels)
|
| 189 |
|
| 190 |
|
| 191 |
+
def main() -> None:
|
| 192 |
+
"""Train MLP on face_orientation features, save best checkpoint and scaler to checkpoints/."""
|
| 193 |
set_seed(CFG["seed"])
|
| 194 |
|
| 195 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 203 |
seed=CFG["seed"],
|
| 204 |
)
|
| 205 |
|
| 206 |
+
model = BaseModel(num_features, num_classes, hidden_sizes=CFG.get("hidden_sizes")).to(device)
|
| 207 |
criterion = nn.CrossEntropyLoss()
|
| 208 |
optimizer = optim.Adam(model.parameters(), lr=CFG["lr"])
|
| 209 |
|
|
|
|
| 222 |
"train_acc": [],
|
| 223 |
"val_loss": [],
|
| 224 |
"val_acc": [],
|
| 225 |
+
"val_f1": [],
|
| 226 |
}
|
| 227 |
|
| 228 |
+
best_val_f1 = 0.0
|
| 229 |
best_val_acc = 0.0
|
| 230 |
|
| 231 |
+
print(f"\n{'Epoch':>6} | {'Train Loss':>10} | {'Train Acc':>9} | {'Val Loss':>10} | {'Val Acc':>9} | {'Val F1':>8}")
|
| 232 |
+
print("-" * 72)
|
| 233 |
|
| 234 |
for epoch in range(1, CFG["epochs"] + 1):
|
| 235 |
train_loss, train_acc = model.training_step(train_loader, optimizer, criterion, device)
|
| 236 |
+
val_loss, val_acc, val_f1 = model.validation_step(val_loader, criterion, device)
|
| 237 |
|
| 238 |
history["epochs"].append(epoch)
|
| 239 |
history["train_loss"].append(round(train_loss, 4))
|
| 240 |
history["train_acc"].append(round(train_acc, 4))
|
| 241 |
history["val_loss"].append(round(val_loss, 4))
|
| 242 |
history["val_acc"].append(round(val_acc, 4))
|
| 243 |
+
history["val_f1"].append(round(val_f1, 4))
|
| 244 |
|
| 245 |
|
| 246 |
current_lr = optimizer.param_groups[0]['lr']
|
|
|
|
| 249 |
task.logger.report_scalar("Accuracy", "Train", float(train_acc), iteration=epoch)
|
| 250 |
task.logger.report_scalar("Loss", "Val", float(val_loss), iteration=epoch)
|
| 251 |
task.logger.report_scalar("Accuracy", "Val", float(val_acc), iteration=epoch)
|
| 252 |
+
task.logger.report_scalar("F1", "Val", float(val_f1), iteration=epoch)
|
| 253 |
task.logger.report_scalar("Learning Rate", "LR", float(current_lr), iteration=epoch)
|
| 254 |
task.logger.flush()
|
| 255 |
|
| 256 |
marker = ""
|
| 257 |
+
if val_f1 > best_val_f1:
|
| 258 |
+
best_val_f1 = val_f1
|
| 259 |
best_val_acc = val_acc
|
| 260 |
torch.save(model.state_dict(), best_ckpt_path)
|
| 261 |
marker = " *"
|
| 262 |
|
| 263 |
+
print(
|
| 264 |
+
f"{epoch:>6} | {train_loss:>10.4f} | {train_acc:>8.2%} | {val_loss:>10.4f} | "
|
| 265 |
+
f"{val_acc:>8.2%} | {val_f1:>8.4f}{marker}"
|
| 266 |
+
)
|
| 267 |
|
| 268 |
+
print(f"\nBest validation F1: {best_val_f1:.4f} (accuracy at best F1: {best_val_acc:.2%})")
|
| 269 |
print(f"Checkpoint saved to: {best_ckpt_path}")
|
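Checkpoint selection now keys on validation F1 rather than accuracy (better for the imbalanced focus classes). The selection rule from the loop above, in isolation:

```python
def pick_best_epoch(val_f1_by_epoch):
    """Keep the epoch with the highest validation F1; ties keep the earlier epoch."""
    best_epoch, best_f1 = 0, -1.0
    for epoch, f1 in val_f1_by_epoch:
        if f1 > best_f1:
            best_epoch, best_f1 = epoch, f1
    return best_epoch, best_f1

assert pick_best_epoch([(1, 0.81), (2, 0.86), (3, 0.84)]) == (2, 0.86)
```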
| 270 |
|
| 271 |
model.load_state_dict(torch.load(best_ckpt_path, weights_only=True))
|
| 272 |
test_loss, test_acc, test_probs, test_preds, test_labels = model.test_step(test_loader, criterion, device)
|
| 273 |
+
test_labels_np = np.asarray(test_labels)
|
| 274 |
+
test_preds_np = np.asarray(test_preds)
|
| 275 |
+
|
| 276 |
+
test_f1 = f1_score(test_labels_np, test_preds_np, average="weighted")
|
| 277 |
if num_classes > 2:
|
| 278 |
+
test_auc = roc_auc_score(test_labels_np, test_probs, multi_class="ovr", average="weighted")
|
| 279 |
else:
|
| 280 |
+
test_auc = roc_auc_score(test_labels_np, test_probs[:, 1])
|
| 281 |
+
|
| 282 |
print(f"\n[TEST] Loss: {test_loss:.4f} | Accuracy: {test_acc:.2%}")
|
| 283 |
print(f"[TEST] F1: {test_f1:.4f} | ROC-AUC: {test_auc:.4f}")
|
| 284 |
|
|
|
|
| 287 |
history["test_f1"] = round(test_f1, 4)
|
| 288 |
history["test_auc"] = round(test_auc, 4)
|
| 289 |
|
| 290 |
+
# Dataset stats for ClearML
|
| 291 |
+
train_labels = train_loader.dataset.labels.numpy()
|
| 292 |
+
val_labels = val_loader.dataset.labels.numpy()
|
| 293 |
+
dataset_stats = {
|
| 294 |
+
"train_size": len(train_loader.dataset),
|
| 295 |
+
"val_size": len(val_loader.dataset),
|
| 296 |
+
"test_size": len(test_loader.dataset),
|
| 297 |
+
"train_class_counts": np.bincount(train_labels, minlength=num_classes).tolist(),
|
| 298 |
+
"val_class_counts": np.bincount(val_labels, minlength=num_classes).tolist(),
|
| 299 |
+
"test_class_counts": np.bincount(test_labels_np, minlength=num_classes).tolist(),
|
| 300 |
+
}
|
| 301 |
+
history["dataset_stats"] = dataset_stats
|
| 302 |
+
|
| 303 |
logs_dir = CFG["logs_dir"]
|
| 304 |
os.makedirs(logs_dir, exist_ok=True)
|
| 305 |
log_path = os.path.join(logs_dir, f"{CFG['model_name']}_training_log.json")
|
|
|
|
| 306 |
with open(log_path, "w") as f:
|
| 307 |
json.dump(history, f, indent=2)
|
|
|
|
| 308 |
print(f"[LOG] Training history saved to: {log_path}")
|
| 309 |
|
|
|
|
| 310 |
scaler_path = os.path.join(ckpt_dir, "scaler_mlp.joblib")
|
| 311 |
joblib.dump(scaler, scaler_path)
|
| 312 |
meta_path = os.path.join(ckpt_dir, "meta_mlp.npz")
|
| 313 |
np.savez(meta_path, feature_names=np.array(SELECTED_FEATURES["face_orientation"]))
|
| 314 |
print(f"[LOG] Scaler and meta saved to {ckpt_dir}")
|
| 315 |
|
| 316 |
+
# ClearML: artifacts, confusion matrix, per-class metrics
|
| 317 |
+
if task is not None:
|
| 318 |
+
task.upload_artifact(name="mlp_best", artifact_object=best_ckpt_path)
|
| 319 |
+
task.upload_artifact(name="training_log", artifact_object=log_path)
|
| 320 |
+
task.logger.report_single_value("test/accuracy", test_acc)
|
| 321 |
+
task.logger.report_single_value("test/f1_weighted", test_f1)
|
| 322 |
+
task.logger.report_single_value("test/roc_auc", test_auc)
|
| 323 |
+
for key, val in dataset_stats.items():
|
| 324 |
+
if isinstance(val, list):
|
| 325 |
+
task.logger.report_single_value(f"dataset/{key}", str(val))
|
| 326 |
+
else:
|
| 327 |
+
task.logger.report_single_value(f"dataset/{key}", val)
|
| 328 |
+
prec, rec, f1_per_class, _ = precision_recall_fscore_support(
|
| 329 |
+
test_labels_np, test_preds_np, average=None, zero_division=0
|
| 330 |
+
)
|
| 331 |
+
for c in range(num_classes):
|
| 332 |
+
task.logger.report_single_value(f"test/class_{c}_precision", float(prec[c]))
|
| 333 |
+
task.logger.report_single_value(f"test/class_{c}_recall", float(rec[c]))
|
| 334 |
+
task.logger.report_single_value(f"test/class_{c}_f1", float(f1_per_class[c]))
|
| 335 |
+
cm = confusion_matrix(test_labels_np, test_preds_np)
|
| 336 |
+
import matplotlib
|
| 337 |
+
matplotlib.use("Agg")
|
| 338 |
+
import matplotlib.pyplot as plt
|
| 339 |
+
fig, ax = plt.subplots(figsize=(6, 5))
|
| 340 |
+
ax.imshow(cm, cmap="Blues")
|
| 341 |
+
ax.set_xticks(range(num_classes))
|
| 342 |
+
ax.set_yticks(range(num_classes))
|
| 343 |
+
ax.set_xticklabels([f"Class {i}" for i in range(num_classes)])
|
| 344 |
+
ax.set_yticklabels([f"Class {i}" for i in range(num_classes)])
|
| 345 |
+
for i in range(num_classes):
|
| 346 |
+
for j in range(num_classes):
|
| 347 |
+
ax.text(j, i, str(cm[i, j]), ha="center", va="center", color="black")
|
| 348 |
+
ax.set_xlabel("Predicted")
|
| 349 |
+
ax.set_ylabel("True")
|
| 350 |
+
ax.set_title("Test set confusion matrix")
|
| 351 |
+
fig.tight_layout()
|
| 352 |
+
task.logger.report_matplotlib_figure(title="Confusion Matrix", series="test", figure=fig, iteration=0)
|
| 353 |
+
plt.close(fig)
|
| 354 |
+
task.logger.flush()
|
| 355 |
+
|
| 356 |
|
| 357 |
if __name__ == "__main__":
|
| 358 |
main()
|
models/xgboost/add_accuracy.py
CHANGED
|
@@ -8,9 +8,7 @@ import os
| 8 | print("Loading dataset for evaluation...")
| 9 | splits, _, _, _ = get_numpy_splits(
| 10 | model_name="face_orientation",
| 11 | -
| 12 | - seed=42,
| 13 | - scale=False
| 14 | )
| 15 | X_train, y_train = splits["X_train"], splits["y_train"]
| 16 | X_val, y_val = splits["X_val"], splits["y_val"]

| 8 | print("Loading dataset for evaluation...")
| 9 | splits, _, _, _ = get_numpy_splits(
| 10 | model_name="face_orientation",
| 11 | + scale=False,
| 12 | )
| 13 | X_train, y_train = splits["X_train"], splits["y_train"]
| 14 | X_val, y_val = splits["X_val"], splits["y_val"]
models/xgboost/config.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
| 1 |
+
"""Shared XGBoost config used by training and evaluation. Loads from config/default.yaml when present."""
|
| 2 |
+
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
|
| 5 |
+
from xgboost import XGBClassifier
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _load_xgb_params():
|
| 9 |
+
try:
|
| 10 |
+
from config import get
|
| 11 |
+
xgb = get("xgboost") or {}
|
| 12 |
+
return {
|
| 13 |
+
"n_estimators": xgb.get("n_estimators", 600),
|
| 14 |
+
"max_depth": xgb.get("max_depth", 8),
|
| 15 |
+
"learning_rate": xgb.get("learning_rate", 0.1489),
|
| 16 |
+
"subsample": xgb.get("subsample", 0.9625),
|
| 17 |
+
"colsample_bytree": xgb.get("colsample_bytree", 0.9013),
|
| 18 |
+
"reg_alpha": xgb.get("reg_alpha", 1.1407),
|
| 19 |
+
"reg_lambda": xgb.get("reg_lambda", 2.4181),
|
| 20 |
+
"eval_metric": xgb.get("eval_metric", "logloss"),
|
| 21 |
+
}
|
| 22 |
+
except Exception:
|
| 23 |
+
return {
|
| 24 |
+
"n_estimators": 600,
|
| 25 |
+
"max_depth": 8,
|
| 26 |
+
"learning_rate": 0.1489,
|
| 27 |
+
"subsample": 0.9625,
|
| 28 |
+
"colsample_bytree": 0.9013,
|
| 29 |
+
"reg_alpha": 1.1407,
|
| 30 |
+
"reg_lambda": 2.4181,
|
| 31 |
+
"eval_metric": "logloss",
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
XGB_BASE_PARAMS = _load_xgb_params()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_xgb_params():
|
| 39 |
+
return deepcopy(XGB_BASE_PARAMS)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def build_xgb_classifier(seed: int, *, verbosity: int = 0, early_stopping_rounds=None):
|
| 43 |
+
params = get_xgb_params()
|
| 44 |
+
params.update(
|
| 45 |
+
{
|
| 46 |
+
"random_state": seed,
|
| 47 |
+
"verbosity": verbosity,
|
| 48 |
+
}
|
| 49 |
+
)
|
| 50 |
+
if early_stopping_rounds is not None:
|
| 51 |
+
params["early_stopping_rounds"] = early_stopping_rounds
|
| 52 |
+
return XGBClassifier(**params)
|
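The pattern in `config.py` above — read overrides from a config source, fall back to hard-coded defaults, and hand callers a deep copy so the shared dict is never mutated — can be sketched with the standard library alone. The names `DEFAULTS` and `load_params` below are hypothetical; the real module pulls overrides from `config/default.yaml` via `config.get("xgboost")`:

```python
from copy import deepcopy

# Hypothetical defaults mirroring a few of the values in models/xgboost/config.py.
DEFAULTS = {"n_estimators": 600, "max_depth": 8, "eval_metric": "logloss"}


def load_params(overrides=None):
    """Layer overrides on top of defaults. deepcopy keeps the shared
    DEFAULTS dict immutable, the same reason config.py exposes
    get_xgb_params() as a deepcopy of XGB_BASE_PARAMS."""
    params = deepcopy(DEFAULTS)
    params.update(overrides or {})
    return params


p = load_params({"max_depth": 6})
```

Because every caller gets its own copy, `train.py` can add `random_state` or `early_stopping_rounds` without leaking those keys into other consumers of the base params.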
models/xgboost/eval_accuracy.py
CHANGED

```diff
@@ -25,8 +25,6 @@ def main():
 
     splits, num_features, num_classes, _ = get_numpy_splits(
         model_name=MODEL_NAME,
-        split_ratios=(0.7, 0.15, 0.15),
-        seed=42,
         scale=False,
     )
     X_test = splits["X_test"]
```
models/xgboost/sweep_local.py
CHANGED

```diff
@@ -14,13 +14,12 @@
 from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
 
 # Import your own dataset loading logic
-from data_preparation.prepare_dataset import get_numpy_splits
+from data_preparation.prepare_dataset import get_default_split_config, get_numpy_splits
 
 # ── General Settings ──────────────────────────────────────────────────────────
 PROJECT_NAME = "FocusGuards Large Group Project"
 BASE_TASK_NAME = "XGBoost Sweep Trial"
-DATA_SPLITS = (
-SEED = 42
+DATA_SPLITS, SEED = get_default_split_config()
 
 # ── Search Space ──────────────────────────────────────────────────────────────
 def objective(trial):
```
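The change above replaces per-script copies of the split ratios and seed with one shared accessor. A minimal sketch of that single-source-of-truth pattern, assuming the repo's default values of `(0.7, 0.15, 0.15)` and seed 42 (the function name below is hypothetical, standing in for `get_default_split_config`):

```python
def default_split_config():
    """One place that owns (split_ratios, seed), so the sweep, training,
    and evaluation scripts can never drift apart. Ratios must sum to 1."""
    ratios = (0.7, 0.15, 0.15)
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError(f"split ratios must sum to 1, got {ratios}")
    return ratios, 42


DATA_SPLITS, SEED = default_split_config()
```

Any script that previously hard-coded `seed=42` or `split_ratios=(0.7, 0.15, 0.15)` now asks this function instead, which is exactly why those keyword arguments disappear from `add_accuracy.py` and `eval_accuracy.py` in this commit.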
models/xgboost/train.py
CHANGED

```diff
@@ -1,42 +1,73 @@
 import json
 import os
 import random
+import sys
 
 import numpy as np
-
-from sklearn.metrics import f1_score, roc_auc_score
-from xgboost import XGBClassifier
+from sklearn.metrics import confusion_matrix, f1_score, precision_recall_fscore_support, roc_auc_score
 
 from data_preparation.prepare_dataset import get_numpy_splits
+from models.xgboost.config import XGB_BASE_PARAMS, build_xgb_classifier
 
 _PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
-
-}
-
+
+
+def _load_cfg():
+    try:
+        from config import get
+        xgb = get("xgboost") or {}
+        data = get("data") or {}
+        ratios = data.get("split_ratios", [0.7, 0.15, 0.15])
+        return {
+            "model_name": get("mlp.model_name") or "face_orientation",
+            "seed": get("mlp.seed") or 42,
+            "split_ratios": tuple(ratios),
+            "scale": False,
+            "checkpoints_dir": os.path.join(_PROJECT_ROOT, "checkpoints"),
+            "logs_dir": os.path.join(_PROJECT_ROOT, "evaluation", "logs"),
+            "xgb_params": dict(XGB_BASE_PARAMS),
+        }
+    except Exception:
+        return {
+            "model_name": "face_orientation",
+            "seed": 42,
+            "split_ratios": (0.7, 0.15, 0.15),
+            "scale": False,
+            "checkpoints_dir": os.path.join(_PROJECT_ROOT, "checkpoints"),
+            "logs_dir": os.path.join(_PROJECT_ROOT, "evaluation", "logs"),
+            "xgb_params": dict(XGB_BASE_PARAMS),
+        }
+
+
+CFG = _load_cfg()
+
+USE_CLEARML = os.environ.get("USE_CLEARML", "0") == "1" or bool(os.environ.get("CLEARML_TASK_ID"))
+CLEARML_QUEUE = os.environ.get("CLEARML_QUEUE", "")
+
 task = None
+if USE_CLEARML:
+    try:
+        from clearml import Task
+        from config import flatten_for_clearml
+        task = Task.init(
+            project_name="Focus Guard",
+            task_name="XGBoost Model Training",
+            tags=["training", "xgboost"],
+        )
+        flat = flatten_for_clearml()
+        for k, v in CFG.get("xgb_params", {}).items():
+            flat[f"xgb_params/{k}"] = v
+        flat["model_name"] = CFG["model_name"]
+        flat["seed"] = CFG["seed"]
+        flat["split_ratios"] = str(CFG["split_ratios"])
+        task.connect(flat)
+        if CLEARML_QUEUE:
+            print(f"[ClearML] Enqueuing to queue '{CLEARML_QUEUE}'.")
+            task.execute_remotely(queue_name=CLEARML_QUEUE)
+            sys.exit(0)
+    except ImportError:
+        task = None
+        USE_CLEARML = False
 
 def set_seed(seed: int):
     random.seed(seed)
@@ -62,19 +93,7 @@ def main():
     X_test, y_test = splits["X_test"], splits["y_test"]
 
     # ── Model ─────────────────────────────────────────────────────
-    model = XGBClassifier(
-        n_estimators=CFG["n_estimators"],
-        max_depth=CFG["max_depth"],
-        learning_rate=CFG["learning_rate"],
-        subsample=CFG["subsample"],
-        colsample_bytree=CFG["colsample_bytree"],
-        reg_alpha=CFG["reg_alpha"],
-        reg_lambda=CFG["reg_lambda"],
-        eval_metric=CFG["eval_metric"],
-        early_stopping_rounds=30,
-        random_state=CFG["seed"],
-        verbosity=1,
-    )
+    model = build_xgb_classifier(CFG["seed"], verbosity=1, early_stopping_rounds=30)
 
     model.fit(
         X_train, y_train,
@@ -82,12 +101,13 @@ def main():
         verbose=10,
     )
     best_it = getattr(model, "best_iteration", None)
-    print(f"[TRAIN] Best iteration: {best_it} / {CFG['n_estimators']}")
+    print(f"[TRAIN] Best iteration: {best_it} / {CFG['xgb_params']['n_estimators']}")
 
     # ── Evaluation ────────────────────────────────────────────────
     evals = model.evals_result()
-
-
+    eval_metric_name = CFG["xgb_params"]["eval_metric"]
+    train_losses = evals["validation_0"][eval_metric_name]
+    val_losses = evals["validation_1"][eval_metric_name]
 
     # Test metrics
     test_preds = model.predict(X_test)
@@ -104,14 +124,53 @@ def main():
     print(f"[TEST] F1: {test_f1:.4f}")
     print(f"[TEST] ROC-AUC: {test_auc:.4f}")
 
-    #
+    # Dataset stats
+    dataset_stats = {
+        "train_size": len(y_train),
+        "val_size": len(y_val),
+        "test_size": len(y_test),
+        "train_class_counts": np.bincount(y_train.astype(int), minlength=num_classes).tolist(),
+        "val_class_counts": np.bincount(y_val.astype(int), minlength=num_classes).tolist(),
+        "test_class_counts": np.bincount(y_test.astype(int), minlength=num_classes).tolist(),
+    }
+
     if task is not None:
         for i, (tl, vl) in enumerate(zip(train_losses, val_losses)):
             task.logger.report_scalar("Loss", "Train", tl, iteration=i + 1)
-            task.logger.report_scalar("Loss", "Val",
-        task.logger.report_single_value("
-        task.logger.report_single_value("
-        task.logger.report_single_value("
+            task.logger.report_scalar("Loss", "Val", vl, iteration=i + 1)
+        task.logger.report_single_value("test/accuracy", test_acc)
+        task.logger.report_single_value("test/f1_weighted", test_f1)
+        task.logger.report_single_value("test/roc_auc", test_auc)
+        for key, val in dataset_stats.items():
+            task.logger.report_single_value(
+                f"dataset/{key}", str(val) if isinstance(val, list) else val
+            )
+        prec, rec, f1_per_class, _ = precision_recall_fscore_support(
+            y_test, test_preds, average=None, zero_division=0
+        )
+        for c in range(num_classes):
+            task.logger.report_single_value(f"test/class_{c}_precision", float(prec[c]))
+            task.logger.report_single_value(f"test/class_{c}_recall", float(rec[c]))
+            task.logger.report_single_value(f"test/class_{c}_f1", float(f1_per_class[c]))
+        cm = confusion_matrix(y_test, test_preds)
+        import matplotlib
+        matplotlib.use("Agg")
+        import matplotlib.pyplot as plt
+        fig, ax = plt.subplots(figsize=(6, 5))
+        ax.imshow(cm, cmap="Blues")
+        ax.set_xticks(range(num_classes))
+        ax.set_yticks(range(num_classes))
+        ax.set_xticklabels([f"Class {i}" for i in range(num_classes)])
+        ax.set_yticklabels([f"Class {i}" for i in range(num_classes)])
+        for i in range(num_classes):
+            for j in range(num_classes):
+                ax.text(j, i, str(cm[i, j]), ha="center", va="center", color="black")
+        ax.set_xlabel("Predicted")
+        ax.set_ylabel("True")
+        ax.set_title("Test set confusion matrix")
+        fig.tight_layout()
+        task.logger.report_matplotlib_figure(title="Confusion Matrix", series="test", figure=fig, iteration=0)
+        plt.close(fig)
         task.logger.flush()
 
     # ── Save checkpoint ───────────────────────────────────────────
@@ -122,17 +181,23 @@ def main():
     print(f"\n[CKPT] Model saved to: {model_path}")
 
     # ── Write JSON log (same schema as MLP) ───────────────────────
+    # pandas-free tree/node count (trees_to_dataframe() needs pandas)
+    booster = model.get_booster()
+    tree_count = int(booster.num_boosted_rounds())
+    node_count = int(sum(tree.count("\n") + 1 for tree in booster.get_dump()))
+
     history = {
         "model_name": f"xgboost_{CFG['model_name']}",
-        "param_count":
-        "
-        "
+        "param_count": node_count,
+        "tree_count": tree_count,
+        "xgb_params": CFG["xgb_params"],
         "epochs": list(range(1, len(train_losses) + 1)),
        "train_loss": [round(v, 4) for v in train_losses],
-        "val_loss":
-        "test_acc":
-        "test_f1":
-        "test_auc":
+        "val_loss": [round(v, 4) for v in val_losses],
+        "test_acc": round(test_acc, 4),
+        "test_f1": round(test_f1, 4),
+        "test_auc": round(test_auc, 4),
+        "dataset_stats": dataset_stats,
     }
 
     logs_dir = CFG["logs_dir"]
@@ -144,6 +209,10 @@ def main():
 
     print(f"[LOG] Training history saved to: {log_path}")
 
+    if task is not None:
+        task.upload_artifact(name="xgboost_model", artifact_object=model_path)
+        task.upload_artifact(name="training_log", artifact_object=log_path)
+
 
 if __name__ == "__main__":
     main()
```
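The "pandas-free tree/node count" added above leans on the fact that `Booster.get_dump()` returns one text string per tree with one node per line, so nodes per tree is newline count plus one. A stand-alone sketch with hand-written stand-ins for two dumped trees (the strings below only mimic the dump format, they are not real xgboost output):

```python
def count_nodes(dump_strings):
    """Sum node counts over dumped trees: one node per line in each
    tree's text dump, so nodes = newlines + 1 per tree."""
    return sum(tree.count("\n") + 1 for tree in dump_strings)


# Hypothetical dumps: a 3-node stump (1 split + 2 leaves) and a single leaf.
fake_dump = [
    "0:[f0<0.5] yes=1,no=2\n\t1:leaf=0.1\n\t2:leaf=-0.1",
    "0:leaf=0.0",
]
n = count_nodes(fake_dump)
```

Using the node total as `param_count` keeps the JSON log schema compatible with the MLP's parameter count while avoiding a pandas dependency (`trees_to_dataframe()` would need it).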
requirements.txt
CHANGED

```diff
@@ -16,6 +16,7 @@ httpx>=0.27.0
 aiosqlite>=0.19.0
 psutil>=5.9.0
 pydantic>=2.0.0
+PyYAML>=6.0
 xgboost>=2.0.0
 clearml>=2.0.2
 pytest>=9.0.0
```
src/App.jsx
CHANGED

```diff
@@ -65,7 +65,7 @@ function App() {
           {renderMenuButton('records', 'My Records')}
           <div className="separator"></div>
 
-          {renderMenuButton('customise', '
+          {renderMenuButton('customise', 'Settings')}
           <div className="separator"></div>
 
           {renderMenuButton('help', 'Help')}
```
src/components/Achievement.jsx
CHANGED

```diff
@@ -199,22 +199,6 @@ function Achievement() {
         </div>
       ) : (
         <>
-          {systemStats && systemStats.cpu_percent != null && (
-            <div style={{
-              textAlign: 'center',
-              marginBottom: '12px',
-              padding: '8px 12px',
-              background: 'rgba(0,0,0,0.2)',
-              borderRadius: '8px',
-              fontSize: '13px',
-              color: '#aaa'
-            }}>
-              Server: CPU <strong style={{ color: '#8f8' }}>{systemStats.cpu_percent}%</strong>
-              {' · '}
-              RAM <strong style={{ color: '#8af' }}>{systemStats.memory_percent}%</strong>
-              {systemStats.memory_used_mb != null && ` (${systemStats.memory_used_mb}/${systemStats.memory_total_mb} MB)`}
-            </div>
-          )}
           <div className="stats-grid">
             <div className="stat-card">
               <div className="stat-number" id="total-sessions">{stats.total_sessions}</div>
```
src/components/Customise.jsx
CHANGED

```diff
@@ -103,7 +103,7 @@ function Customise() {
 
   return (
     <main id="page-e" className="page">
-      <h1 className="page-title">
+      <h1 className="page-title">Settings</h1>
 
       <div className="settings-container">
         {/* Data Management */}
```
src/components/FocusPageLocal.jsx
CHANGED

```diff
@@ -518,20 +518,16 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
       return;
     }
 
-    //
     const sessionDuration = Math.floor((Date.now() - (videoManager.sessionStartTime || Date.now())) / 1000);
+    const totalFrames = currentStats.framesProcessed || 0;
+    const focusedFrames = currentStats.focusedFrames ?? 0;
+    const focusScore = totalFrames > 0 ? focusedFrames / totalFrames : 0;
 
-    //
-    const focusScore = currentStats.framesProcessed > 0
-      ? (currentStats.framesProcessed * (currentStats.currentStatus ? 1 : 0)) / currentStats.framesProcessed
-      : 0;
-
-    //
     setSessionResult({
       duration_seconds: sessionDuration,
       focus_score: focusScore,
-      total_frames:
-      focused_frames:
+      total_frames: totalFrames,
+      focused_frames: focusedFrames
    });
  };
```
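The old expression multiplied and divided by the same `framesProcessed`, so it always collapsed to `currentStatus ? 1 : 0` — the score of the final frame, not the session. The fix divides an accumulated focused-frame count by the total. A Python sketch of the corrected computation (mirroring, not reproducing, the JSX):

```python
def focus_score(focused_frames: int, total_frames: int) -> float:
    """Fraction of processed frames classified as focused; 0.0 for an
    empty session so there is no division by zero."""
    return focused_frames / total_frames if total_frames > 0 else 0.0


score = focus_score(45, 60)   # session score reflects all frames, not just the last
empty = focus_score(0, 0)     # empty session is well-defined
```

Under the old formula a session that was focused for 59 of 60 frames but ended unfocused would score 0; the ratio above gives the expected 45/60 = 0.75 style result instead.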
src/utils/VideoManagerLocal.js
CHANGED

```diff
@@ -59,10 +59,11 @@ export class VideoManagerLocal {
     // Calibration state
     this.calibration = createCalibrationState();
 
-    // Performance metrics
+    // Performance metrics (focusedFrames = count of frames where focused was true this session)
     this.stats = {
       framesSent: 0,
       framesProcessed: 0,
+      focusedFrames: 0,
       avgLatency: 0,
       lastLatencies: []
     };
@@ -411,6 +412,8 @@ export class VideoManagerLocal {
       case 'session_started':
         this.sessionId = data.session_id;
         this.sessionStartTime = Date.now();
+        this.stats.framesProcessed = 0;
+        this.stats.focusedFrames = 0;
        console.log('Session started:', this.sessionId);
         if (this.callbacks.onSessionStart) {
           this.callbacks.onSessionStart(this.sessionId);
@@ -419,6 +422,7 @@ export class VideoManagerLocal {
 
       case 'detection':
         this.stats.framesProcessed++;
+        if (data.focused) this.stats.focusedFrames++;
 
         // Track latency from send→receive
         const now = performance.now();
```
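The counter lifecycle above — zero everything on `session_started`, increment on each `detection` — is what keeps scores from mixing frames across sessions. A small Python sketch of the same state machine (a hypothetical analogue of the JS class, for illustration only):

```python
class SessionStats:
    """Per-session frame counters, mirroring the stats object in
    VideoManagerLocal.js."""

    def __init__(self):
        self.frames_processed = 0
        self.focused_frames = 0

    def on_session_started(self):
        # Reset so a new session never inherits counts from the last one.
        self.frames_processed = 0
        self.focused_frames = 0

    def on_detection(self, focused: bool):
        self.frames_processed += 1
        if focused:
            self.focused_frames += 1


stats = SessionStats()
stats.on_detection(True)          # stray frame before the session starts
stats.on_session_started()        # counters reset here
for focused in (True, True, False):
    stats.on_detection(focused)
```

Without the reset in `on_session_started`, the stray pre-session frame would inflate the next session's totals — the exact leak the commit's `session_started` handler fixes.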
tests/test_api_settings.py
CHANGED

```diff
@@ -24,27 +24,16 @@ def test_get_settings_default_fields():
     resp = client.get("/api/settings")
     assert resp.status_code == 200
     data = resp.json()
-    assert "sensitivity" in data
-    assert "notification_enabled" in data
-    assert "notification_threshold" in data
-    assert "frame_rate" in data
     assert "model_name" in data
+    assert "l2cs_boost" in data
 
 
-def test_update_settings_clamped_ranges():
+def test_update_settings_model_name():
     client = _make_test_client()
     with client:
-        # get setting
         r0 = client.get("/api/settings")
         assert r0.status_code == 200
-
-        # set unlogic params
-        payload = {
-            "sensitivity": 100,
-            "notification_enabled": False,
-            "notification_threshold": 1,
-            "frame_rate": 1000,
-        }
+        payload = {"model_name": "xgboost"}
         r_put = client.put("/api/settings", json=payload)
         assert r_put.status_code == 200
         body = r_put.json()
@@ -53,8 +42,5 @@ def test_update_settings_model_name():
 
         r1 = client.get("/api/settings")
         data = r1.json()
-        assert
-        assert bool(data["notification_enabled"]) is False
-        assert 5 <= data["notification_threshold"] <= 300
-        assert 5 <= data["frame_rate"] <= 60
+        assert data["model_name"] == "xgboost"
 
```
tests/test_data_preparation.py
CHANGED

```diff
@@ -10,10 +10,18 @@ if PROJECT_ROOT not in sys.path:
 from data_preparation.prepare_dataset import (
     SELECTED_FEATURES,
     _generate_synthetic_data,
+    get_default_split_config,
     get_numpy_splits,
 )
 
 
+def test_get_default_split_config():
+    ratios, seed = get_default_split_config()
+    assert len(ratios) == 3
+    assert abs(sum(ratios) - 1.0) < 1e-6
+    assert seed >= 0
+
+
 def test_generate_synthetic_data_shape():
     X, y = _generate_synthetic_data("face_orientation")
     assert X.shape[0] == 500
@@ -22,18 +30,23 @@ def test_generate_synthetic_data_shape():
 
 
 def test_get_numpy_splits_consistency():
-
+    split_ratios, seed = get_default_split_config()
+    splits, num_features, num_classes, scaler = get_numpy_splits(
+        "face_orientation", split_ratios=split_ratios, seed=seed
+    )
 
-    # train/val/test each have samples
     n_train = len(splits["y_train"])
     n_val = len(splits["y_val"])
     n_test = len(splits["y_test"])
     assert n_train > 0
     assert n_val > 0
     assert n_test > 0
-
-    # feature dim should same as num_features
     assert splits["X_train"].shape[1] == num_features
-
     assert num_classes >= 2
 
+    # Same seed and ratios produce same split (deterministic)
+    splits2, _, _, _ = get_numpy_splits(
+        "face_orientation", split_ratios=split_ratios, seed=seed
+    )
+    np.testing.assert_array_equal(splits["y_test"], splits2["y_test"])
+
```
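The determinism check added above relies on a basic property of seeded RNGs: the same seed reproduces the same shuffle, hence the same train/val/test assignment. That property can be demonstrated with the standard library alone (this sketch is not the project's split code, just the invariant the test asserts):

```python
import random


def shuffled_indices(n: int, seed: int):
    """Shuffle 0..n-1 with a dedicated seeded generator. A fixed seed
    must always reproduce the same permutation."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return idx


a = shuffled_indices(100, seed=42)
b = shuffled_indices(100, seed=42)
```

Using `random.Random(seed)` (a private generator) rather than reseeding the module-level RNG keeps the shuffle reproducible even if other code consumes random numbers in between.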