Yingtao Zheng (k23158987) committed on
Commit
52d831a
·
unverified ·
2 Parent(s): 28d0d9e2ea6266

Merge pull request #6 from k23172173/main

Browse files

Update Dev to the newest version and correct wrong-branch mistakes

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +0 -16
  2. models/face_orientation_model/best_model.pt → MLP/models/meta_20260224_024200.npz +2 -2
  3. MLP/models/mlp_20260224_024200.joblib +3 -0
  4. MLP/models/scaler_20260224_024200.joblib +3 -0
  5. README.md +7 -7
  6. data_preparation/CNN/eye_crops/val/open/.gitkeep +1 -0
  7. data_preparation/MLP/explore_collected_data.ipynb +0 -0
  8. data_preparation/MLP/train_mlp.ipynb +0 -0
  9. data_preparation/README.md +40 -2
  10. data_preparation/collected_Abdelrahman/abdelrahman_20260306_023035.npz +3 -0
  11. data_preparation/collected_Ayten/ayten_session_1.npz +3 -0
  12. data_preparation/collected_Jarek/Jarek_20260225_012931.npz +3 -0
  13. data_preparation/collected_Junhao/Junhao_20260303_113554.npz +3 -0
  14. data_preparation/collected_Kexin/kexin2_20260305_180229.npz +3 -0
  15. data_preparation/collected_Kexin/kexin_20260224_151043.npz +3 -0
  16. data_preparation/collected_Langyuan/Langyuan_20260303_153145.npz +3 -0
  17. data_preparation/collected_Mohamed/session_20260224_010131.npz +3 -0
  18. data_preparation/collected_Saba/saba_20260306_230710.npz +3 -0
  19. data_preparation/collected_Yingtao/Yingtao_20260306_023937.npz +3 -0
  20. evaluation/README.md +1 -1
  21. {models/attention_score_fusion → evaluation/logs}/.gitkeep +0 -0
  22. models/README.md +7 -5
  23. models/attention/__init__.py +1 -0
  24. models/{eye_behaviour_model/.gitkeep → attention/classifier.py} +0 -0
  25. models/attention/collect_features.py +349 -0
  26. models/{face_landmarks_pretrained/.gitkeep → attention/fusion.py} +0 -0
  27. models/{face_orientation_model/.gitkeep → attention/train.py} +0 -0
  28. models/cnn/CNN_MODEL/.claude/settings.local.json +7 -0
  29. models/cnn/CNN_MODEL/.gitattributes +1 -0
  30. models/cnn/CNN_MODEL/.gitignore +4 -0
  31. models/cnn/CNN_MODEL/README.md +74 -0
  32. models/cnn/CNN_MODEL/notebooks/eye_classifier_colab.ipynb +0 -0
  33. models/cnn/CNN_MODEL/scripts/focus_infer.py +199 -0
  34. models/cnn/CNN_MODEL/scripts/predict_image.py +49 -0
  35. models/cnn/CNN_MODEL/scripts/video_infer.py +281 -0
  36. models/cnn/CNN_MODEL/scripts/webcam_live.py +184 -0
  37. models/cnn/CNN_MODEL/weights/yolo11s-cls.pt +3 -0
  38. models/cnn/__init__.py +0 -0
  39. models/cnn/eye_attention/__init__.py +1 -0
  40. models/cnn/eye_attention/classifier.py +69 -0
  41. models/cnn/eye_attention/crop.py +70 -0
  42. models/cnn/eye_attention/train.py +0 -0
  43. models/geometric/__init__.py +0 -0
  44. models/geometric/eye_behaviour/__init__.py +0 -0
  45. models/geometric/eye_behaviour/eye_scorer.py +164 -0
  46. models/geometric/face_orientation/__init__.py +1 -0
  47. models/geometric/face_orientation/head_pose.py +112 -0
  48. models/mlp/__init__.py +0 -0
  49. models/{train.py → mlp/train.py} +32 -7
  50. models/pretrained/__init__.py +0 -0
.gitignore CHANGED
@@ -1,4 +1,3 @@
1
- # Python
2
  __pycache__/
3
  *.py[cod]
4
  *$py.class
@@ -12,25 +11,10 @@ env/
12
  .eggs/
13
  dist/
14
  build/
15
-
16
- # IDE
17
  .idea/
18
  .vscode/
19
  *.swp
20
  *.swo
21
-
22
- # Data and outputs (optional: uncomment if you don’t want to track large files)
23
- # data_preparation/raw/
24
- # data_preparation/processed/*.npy
25
- # evaluation/logs/
26
- # evaluation/results/
27
-
28
- # Model checkpoints (uncomment to ignore .pt files)
29
- # *.pt
30
-
31
- # Project
32
  docs/
33
-
34
- # OS
35
  .DS_Store
36
  Thumbs.db
 
 
1
  __pycache__/
2
  *.py[cod]
3
  *$py.class
 
11
  .eggs/
12
  dist/
13
  build/
 
 
14
  .idea/
15
  .vscode/
16
  *.swp
17
  *.swo
 
 
 
 
 
 
 
 
 
 
 
18
  docs/
 
 
19
  .DS_Store
20
  Thumbs.db
models/face_orientation_model/best_model.pt → MLP/models/meta_20260224_024200.npz RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:18c1f2750c7274e72538b94afcc9f0243287a5b2eb8fcce6be6e4ae18ec59cb0
3
- size 15033
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:769bb62c7bf04aafd808e9b2623e795c2d92bcb933313ebf553d6fce5ebe7143
3
+ size 1616
MLP/models/mlp_20260224_024200.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a72933fcf2d0aed998c6303ea4298c04618d937c7f17bf492e76efcf3b4b54d7
3
+ size 50484
MLP/models/scaler_20260224_024200.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f9ef3721cee28f1472886556e001d0f6ed0abe09011d979a70ca9bf447d453e
3
+ size 823
README.md CHANGED
@@ -1,10 +1,10 @@
1
- # GAP — FocusGuard
2
 
3
- Real-time focus estimation from webcam (head pose + eye behaviour).
4
 
5
- ## Layout
 
 
 
6
 
7
- - **data_preparation/** Dataset team (raw data, processed, scripts)
8
- - **models/** — Face orientation, eye behaviour, fusion, landmarks. Training entry: `models/train.py`
9
- - **evaluation/** — Metrics, runs, results
10
- - **ui/** — Live demo + session view
 
1
+ # FocusGuard
2
 
3
+ Webcam-based focus detection: face mesh, head pose, eye (geometry or YOLO), plus an MLP trained on collected features.
4
 
5
+ - **data_preparation/** — collect data, notebooks, processed/collected files
6
+ - **models/** — face mesh, head pose, eye scorer, YOLO classifier, MLP training, attention feature collection
7
+ - **evaluation/** — metrics and run logs
8
+ - **ui/** — live demo (geometry+YOLO or MLP-only)
9
 
10
+ Run from here: `pip install -r requirements.txt` then `python ui/live_demo.py` or `python ui/live_demo.py --mlp`.
 
 
 
data_preparation/CNN/eye_crops/val/open/.gitkeep ADDED
@@ -0,0 +1 @@
 
 
1
+
data_preparation/MLP/explore_collected_data.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
data_preparation/MLP/train_mlp.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
data_preparation/README.md CHANGED
@@ -1,3 +1,41 @@
1
- # data_preparation
2
 
3
- Dataset team owns layout and scripts here.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data Preparation
2
 
3
+ ## Folder Structure
4
+
5
+ ### collected/
6
+ Contains raw session files in `.npz` format.
7
+ Generated using:
8
+
9
+ python -m models.attention.collect_features
10
+
11
+ Each session includes:
12
+ - 17-dimensional feature vectors
13
+ - Corresponding labels
14
+
15
+ ---
16
+
17
+
18
+ ### MLP/
19
+ Contains notebooks for:
20
+ - Exploring collected data
21
+ - Training the sklearn MLP model (10 features)
22
+
23
+ Trained models are saved to:
24
+ ../MLP/models/
25
+
26
+ ---
27
+
28
+ ### CNN/
29
+ Eye crop directory structure for CNN training (YOLO).
30
+
31
+ ---
32
+
33
+ ## Collecting Data
34
+
35
+ **Step-by-step**
36
+
37
+ 1. From repo root Install deps: `pip install -r requirements.txt`.
38
+ 3. Run: `python -m models.attention.collect_features --name yourname`.
39
+ 4. Webcam opens. Look at the camera; press **1** when focused, **0** when unfocused. Switch every 10–30 sec so you get both labels.
40
+ 5. Press **p** to pause/resume.
41
+ 6. Press **q** when done. One `.npz` is saved to `data_preparation/collected/` (17 features + labels).
data_preparation/collected_Abdelrahman/abdelrahman_20260306_023035.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2c48532150182c8933d4595e0a0711365645b699647e99976575b7c2adffaf8
3
+ size 1207980
data_preparation/collected_Ayten/ayten_session_1.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbecdbffa1c1b03b3b0fb5f715dcb4ff885ecc67da4aff78e6952b8847a96014
3
+ size 1341056
data_preparation/collected_Jarek/Jarek_20260225_012931.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fa68f4d587eee8d645b23b463a9f1c848b9bacc2adb68603d5fa9cd8cb744c7
3
+ size 1128864
data_preparation/collected_Junhao/Junhao_20260303_113554.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec321ee79800c04fdc0f999690d07970445aeca61f977bf6537880bbc996b5e5
3
+ size 678336
data_preparation/collected_Kexin/kexin2_20260305_180229.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e96fe17571fa1fcccc1b4bd0c8838270498883e4db6a608c4d4d4c3a8ac1d0d
3
+ size 1129700
data_preparation/collected_Kexin/kexin_20260224_151043.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d402ca4e66910a2e174c4f4beec5d7b3db6a04213d29673b227ce6ef04b39c4
3
+ size 1329732
data_preparation/collected_Langyuan/Langyuan_20260303_153145.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c679cdba334b2f3f0953b7e44f7209056277c826e2b7b5cfcf2b8b750898400
3
+ size 1198784
data_preparation/collected_Mohamed/session_20260224_010131.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a784f703c13b83911f47ec507d32c25942a07572314b8a77cbf40ca8cdff16f
3
+ size 1006428
data_preparation/collected_Saba/saba_20260306_230710.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db1cab5ddcf9988856c5bdca1183c8eba4647365e675a1d8a200d12f6b5d2097
3
+ size 663212
data_preparation/collected_Yingtao/Yingtao_20260306_023937.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a75af17e25dca5f06ea9e7443ea5fee9db638f68a5910e014ee7cb8b7ae80fd
3
+ size 1338776
evaluation/README.md CHANGED
@@ -1,3 +1,3 @@
1
  # evaluation
2
 
3
- Metrics, experiment configs, and results live here.
 
1
  # evaluation
2
 
3
+ Place metrics scripts, run configs, and results here. Logs dir is used by `models.mlp.train` for training logs.
{models/attention_score_fusion → evaluation/logs}/.gitkeep RENAMED
File without changes
models/README.md CHANGED
@@ -1,8 +1,10 @@
1
  # models
2
 
3
- - `face_orientation_model/`S_face
4
- - `eye_behaviour_model/`S_eye
5
- - `attention_score_fusion/`fusion + smoothing
6
- - `face_landmarks_pretrained/` — MediaPipe FaceMesh (no training)
 
 
7
 
8
- `train.py` trains the MLP on feature vectors; `prepare_dataset.py` loads from `data_preparation/processed/` or synthetic.
 
1
  # models
2
 
3
+ - **cnn/eye_attention/**YOLO open/closed eye classifier, crop helper, train stub
4
+ - **mlp/**PyTorch MLP on feature vectors (face_orientation / eye_behaviour); checkpoints under `mlp/face_orientation_model/`, `mlp/eye_behaviour_model/`
5
+ - **geometric/face_orientation/**head pose (solvePnP). **geometric/eye_behaviour/** — EAR, gaze, MAR
6
+ - **pretrained/face_mesh/** — MediaPipe face landmarks (no training)
7
+ - **attention/** — webcam feature collection (17-d), stubs for train/classifier/fusion
8
+ - **prepare_dataset.py** — loads from `data_preparation/processed/` or synthetic; used by `mlp/train.py`
9
 
10
+ Run legacy MLP training: `python -m models.mlp.train`. The sklearn MLP used in the live demo is trained in `data_preparation/MLP/train_mlp.ipynb` and saved under `../MLP/models/`.
models/attention/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
models/{eye_behaviour_model/.gitkeep → attention/classifier.py} RENAMED
File without changes
models/attention/collect_features.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Usage: python -m models.attention.collect_features [--name alice] [--duration 600]
2
+
3
+ import argparse
4
+ import collections
5
+ import math
6
+ import os
7
+ import sys
8
+ import time
9
+
10
+ import cv2
11
+ import numpy as np
12
+
13
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
+ if _PROJECT_ROOT not in sys.path:
15
+ sys.path.insert(0, _PROJECT_ROOT)
16
+
17
+ from models.pretrained.face_mesh.face_mesh import FaceMeshDetector
18
+ from models.geometric.face_orientation.head_pose import HeadPoseEstimator
19
+ from models.geometric.eye_behaviour.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar
20
+
21
+ FONT = cv2.FONT_HERSHEY_SIMPLEX
22
+ GREEN = (0, 255, 0)
23
+ RED = (0, 0, 255)
24
+ WHITE = (255, 255, 255)
25
+ YELLOW = (0, 255, 255)
26
+ ORANGE = (0, 165, 255)
27
+ GRAY = (120, 120, 120)
28
+
29
+ FEATURE_NAMES = [
30
+ "ear_left", "ear_right", "ear_avg", "h_gaze", "v_gaze", "mar",
31
+ "yaw", "pitch", "roll", "s_face", "s_eye", "gaze_offset", "head_deviation",
32
+ "perclos", "blink_rate", "closure_duration", "yawn_duration",
33
+ ]
34
+
35
+ NUM_FEATURES = len(FEATURE_NAMES)
36
+ assert NUM_FEATURES == 17
37
+
38
+
39
+ class TemporalTracker:
40
+ EAR_BLINK_THRESH = 0.21
41
+ MAR_YAWN_THRESH = 0.04
42
+ PERCLOS_WINDOW = 60
43
+ BLINK_WINDOW_SEC = 30.0
44
+
45
+ def __init__(self):
46
+ self.ear_history = collections.deque(maxlen=self.PERCLOS_WINDOW)
47
+ self.blink_timestamps = collections.deque()
48
+ self._eyes_closed = False
49
+ self._closure_start = None
50
+ self._yawn_start = None
51
+
52
+ def update(self, ear_avg, mar, now=None):
53
+ if now is None:
54
+ now = time.time()
55
+
56
+ closed = ear_avg < self.EAR_BLINK_THRESH
57
+ self.ear_history.append(1.0 if closed else 0.0)
58
+ perclos = sum(self.ear_history) / len(self.ear_history) if self.ear_history else 0.0
59
+
60
+ if self._eyes_closed and not closed:
61
+ self.blink_timestamps.append(now)
62
+ self._eyes_closed = closed
63
+
64
+ cutoff = now - self.BLINK_WINDOW_SEC
65
+ while self.blink_timestamps and self.blink_timestamps[0] < cutoff:
66
+ self.blink_timestamps.popleft()
67
+ blink_rate = len(self.blink_timestamps) * (60.0 / self.BLINK_WINDOW_SEC)
68
+
69
+ if closed:
70
+ if self._closure_start is None:
71
+ self._closure_start = now
72
+ closure_dur = now - self._closure_start
73
+ else:
74
+ self._closure_start = None
75
+ closure_dur = 0.0
76
+
77
+ yawning = mar > self.MAR_YAWN_THRESH
78
+ if yawning:
79
+ if self._yawn_start is None:
80
+ self._yawn_start = now
81
+ yawn_dur = now - self._yawn_start
82
+ else:
83
+ self._yawn_start = None
84
+ yawn_dur = 0.0
85
+
86
+ return perclos, blink_rate, closure_dur, yawn_dur
87
+
88
+
89
+ def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal):
90
+ from models.geometric.eye_behaviour.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear
91
+
92
+ ear_left = compute_ear(landmarks, _LEFT_EYE_EAR)
93
+ ear_right = compute_ear(landmarks, _RIGHT_EYE_EAR)
94
+ ear_avg = (ear_left + ear_right) / 2.0
95
+ h_gaze, v_gaze = compute_gaze_ratio(landmarks)
96
+ mar = compute_mar(landmarks)
97
+
98
+ angles = head_pose.estimate(landmarks, w, h)
99
+ yaw = angles[0] if angles else 0.0
100
+ pitch = angles[1] if angles else 0.0
101
+ roll = angles[2] if angles else 0.0
102
+
103
+ s_face = head_pose.score(landmarks, w, h)
104
+ s_eye = eye_scorer.score(landmarks)
105
+
106
+ gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
107
+ head_deviation = math.sqrt(yaw ** 2 + pitch ** 2)
108
+
109
+ perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)
110
+
111
+ return np.array([
112
+ ear_left, ear_right, ear_avg,
113
+ h_gaze, v_gaze,
114
+ mar,
115
+ yaw, pitch, roll,
116
+ s_face, s_eye,
117
+ gaze_offset,
118
+ head_deviation,
119
+ perclos, blink_rate, closure_dur, yawn_dur,
120
+ ], dtype=np.float32)
121
+
122
+
123
+ def quality_report(labels):
124
+ n = len(labels)
125
+ n1 = int((labels == 1).sum())
126
+ n0 = n - n1
127
+ transitions = int(np.sum(np.diff(labels) != 0))
128
+ duration_sec = n / 30.0 # approximate at 30fps
129
+
130
+ warnings = []
131
+
132
+ print(f"\n{'='*50}")
133
+ print(f" DATA QUALITY REPORT")
134
+ print(f"{'='*50}")
135
+ print(f" Total samples : {n}")
136
+ print(f" Focused : {n1} ({n1/max(n,1)*100:.1f}%)")
137
+ print(f" Unfocused : {n0} ({n0/max(n,1)*100:.1f}%)")
138
+ print(f" Duration : {duration_sec:.0f}s ({duration_sec/60:.1f} min)")
139
+ print(f" Transitions : {transitions}")
140
+ if transitions > 0:
141
+ print(f" Avg segment : {n/transitions:.0f} frames ({n/transitions/30:.1f}s)")
142
+
143
+ # checks
144
+ if duration_sec < 120:
145
+ warnings.append(f"TOO SHORT: {duration_sec:.0f}s — aim for 5-10 minutes (300-600s)")
146
+
147
+ if n < 3000:
148
+ warnings.append(f"LOW SAMPLE COUNT: {n} frames — aim for 9000+ (5 min at 30fps)")
149
+
150
+ balance = n1 / max(n, 1)
151
+ if balance < 0.3 or balance > 0.7:
152
+ warnings.append(f"IMBALANCED: {balance:.0%} focused — aim for 35-65% focused")
153
+
154
+ if transitions < 10:
155
+ warnings.append(f"TOO FEW TRANSITIONS: {transitions} — switch every 10-30s, aim for 20+")
156
+
157
+ if transitions == 1:
158
+ warnings.append("SINGLE BLOCK: you recorded one unfocused + one focused block — "
159
+ "model will learn temporal position, not focus patterns")
160
+
161
+ if warnings:
162
+ print(f"\n ⚠️ WARNINGS ({len(warnings)}):")
163
+ for w in warnings:
164
+ print(f" • {w}")
165
+ print(f"\n Consider re-recording this session.")
166
+ else:
167
+ print(f"\n ✅ All checks passed!")
168
+
169
+ print(f"{'='*50}\n")
170
+ return len(warnings) == 0
171
+
172
+
173
+ # ---------------------------------------------------------------------------
174
+ # Main
175
+ def main():
176
+ parser = argparse.ArgumentParser()
177
+ parser.add_argument("--name", type=str, default="session",
178
+ help="Your name or session ID")
179
+ parser.add_argument("--camera", type=int, default=0,
180
+ help="Camera index")
181
+ parser.add_argument("--duration", type=int, default=600,
182
+ help="Max recording time (seconds, default 10 min)")
183
+ parser.add_argument("--output-dir", type=str,
184
+ default=os.path.join(_PROJECT_ROOT, "data_preparation", "collected"),
185
+ help="Where to save .npz files")
186
+ args = parser.parse_args()
187
+
188
+ os.makedirs(args.output_dir, exist_ok=True)
189
+
190
+ detector = FaceMeshDetector()
191
+ head_pose = HeadPoseEstimator()
192
+ eye_scorer = EyeBehaviourScorer()
193
+ temporal = TemporalTracker()
194
+
195
+ cap = cv2.VideoCapture(args.camera)
196
+ if not cap.isOpened():
197
+ print("[COLLECT] ERROR: can't open camera")
198
+ return
199
+
200
+ print("[COLLECT] Data Collection Tool")
201
+ print(f"[COLLECT] Session: {args.name}, max {args.duration}s")
202
+ print(f"[COLLECT] Features per frame: {NUM_FEATURES}")
203
+ print("[COLLECT] Controls:")
204
+ print(" 1 = FOCUSED (looking at screen normally)")
205
+ print(" 0 = NOT FOCUSED (phone, away, eyes closed, yawning)")
206
+ print(" p = pause")
207
+ print(" q = save & quit")
208
+ print()
209
+ print("[COLLECT] TIPS for good data:")
210
+ print(" • Switch between 1 and 0 every 10-30 seconds")
211
+ print(" • Aim for 20+ transitions total")
212
+ print(" • Act out varied scenarios: reading, phone, talking, drowsy")
213
+ print(" • Record at least 5 minutes")
214
+ print()
215
+
216
+ features_list = []
217
+ labels_list = []
218
+ label = None # None = paused
219
+ transitions = 0 # count label switches
220
+ prev_label = None
221
+ status = "PAUSED -- press 1 (focused) or 0 (not focused)"
222
+ t_start = time.time()
223
+ prev_time = time.time()
224
+ fps = 0.0
225
+
226
+ try:
227
+ while True:
228
+ elapsed = time.time() - t_start
229
+ if elapsed > args.duration:
230
+ print(f"[COLLECT] Time limit ({args.duration}s)")
231
+ break
232
+
233
+ ret, frame = cap.read()
234
+ if not ret:
235
+ break
236
+
237
+ h, w = frame.shape[:2]
238
+ landmarks = detector.process(frame)
239
+ face_ok = landmarks is not None
240
+
241
+ # record if labeling + face visible
242
+ if face_ok and label is not None:
243
+ vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
244
+ features_list.append(vec)
245
+ labels_list.append(label)
246
+
247
+ # count transitions
248
+ if prev_label is not None and label != prev_label:
249
+ transitions += 1
250
+ prev_label = label
251
+
252
+ now = time.time()
253
+ fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
254
+ prev_time = now
255
+
256
+ # --- draw UI ---
257
+ n = len(labels_list)
258
+ n1 = sum(1 for x in labels_list if x == 1)
259
+ n0 = n - n1
260
+ remaining = max(0, args.duration - elapsed)
261
+
262
+ bar_color = GREEN if label == 1 else (RED if label == 0 else (80, 80, 80))
263
+ cv2.rectangle(frame, (0, 0), (w, 70), (0, 0, 0), -1)
264
+ cv2.putText(frame, status, (10, 22), FONT, 0.55, bar_color, 2, cv2.LINE_AA)
265
+ cv2.putText(frame, f"Samples: {n} (F:{n1} U:{n0}) Switches: {transitions}",
266
+ (10, 48), FONT, 0.42, WHITE, 1, cv2.LINE_AA)
267
+ cv2.putText(frame, f"FPS:{fps:.0f}", (w - 80, 22), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
268
+ cv2.putText(frame, f"{int(remaining)}s left", (w - 80, 48), FONT, 0.42, YELLOW, 1, cv2.LINE_AA)
269
+
270
+ if n > 0:
271
+ bar_w = min(w - 20, 300)
272
+ bar_x = w - bar_w - 10
273
+ bar_y = 58
274
+ frac = n1 / n
275
+ cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + 8), (40, 40, 40), -1)
276
+ cv2.rectangle(frame, (bar_x, bar_y), (bar_x + int(bar_w * frac), bar_y + 8), GREEN, -1)
277
+ cv2.putText(frame, f"{frac:.0%}F", (bar_x + bar_w + 4, bar_y + 8),
278
+ FONT, 0.3, GRAY, 1, cv2.LINE_AA)
279
+
280
+ if not face_ok:
281
+ cv2.putText(frame, "NO FACE", (w // 2 - 60, h // 2), FONT, 0.7, RED, 2, cv2.LINE_AA)
282
+
283
+ # red dot = recording
284
+ if label is not None and face_ok:
285
+ cv2.circle(frame, (w - 20, 80), 8, RED, -1)
286
+
287
+ # live warnings
288
+ warn_y = h - 35
289
+ if n > 100 and transitions < 3:
290
+ cv2.putText(frame, "! Switch more often (aim for 20+ transitions)",
291
+ (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
292
+ warn_y -= 18
293
+ if elapsed > 30 and n > 0:
294
+ bal = n1 / n
295
+ if bal < 0.25 or bal > 0.75:
296
+ cv2.putText(frame, f"! Imbalanced ({bal:.0%} focused) - record more of the other",
297
+ (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
298
+ warn_y -= 18
299
+
300
+ cv2.putText(frame, "1:focused 0:unfocused p:pause q:save+quit",
301
+ (10, h - 10), FONT, 0.38, GRAY, 1, cv2.LINE_AA)
302
+
303
+ cv2.imshow("FocusGuard -- Data Collection", frame)
304
+
305
+ key = cv2.waitKey(1) & 0xFF
306
+ if key == ord("1"):
307
+ label = 1
308
+ status = "Recording: FOCUSED"
309
+ print(f"[COLLECT] -> FOCUSED (n={n}, transitions={transitions})")
310
+ elif key == ord("0"):
311
+ label = 0
312
+ status = "Recording: NOT FOCUSED"
313
+ print(f"[COLLECT] -> NOT FOCUSED (n={n}, transitions={transitions})")
314
+ elif key == ord("p"):
315
+ label = None
316
+ status = "PAUSED"
317
+ print(f"[COLLECT] paused (n={n})")
318
+ elif key == ord("q"):
319
+ break
320
+
321
+ finally:
322
+ cap.release()
323
+ cv2.destroyAllWindows()
324
+ detector.close()
325
+
326
+ if len(features_list) > 0:
327
+ feats = np.stack(features_list)
328
+ labs = np.array(labels_list, dtype=np.int64)
329
+
330
+ ts = time.strftime("%Y%m%d_%H%M%S")
331
+ fname = f"{args.name}_{ts}.npz"
332
+ fpath = os.path.join(args.output_dir, fname)
333
+ np.savez(fpath,
334
+ features=feats,
335
+ labels=labs,
336
+ feature_names=np.array(FEATURE_NAMES))
337
+
338
+ print(f"\n[COLLECT] Saved {len(labs)} samples -> {fpath}")
339
+ print(f" Shape: {feats.shape} ({NUM_FEATURES} features)")
340
+
341
+ quality_report(labs)
342
+ else:
343
+ print("\n[COLLECT] No data collected")
344
+
345
+ print("[COLLECT] Done")
346
+
347
+
348
+ if __name__ == "__main__":
349
+ main()
models/{face_landmarks_pretrained/.gitkeep → attention/fusion.py} RENAMED
File without changes
models/{face_orientation_model/.gitkeep → attention/train.py} RENAMED
File without changes
models/cnn/CNN_MODEL/.claude/settings.local.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(# Check Dataset_subset counts echo \"\"=== Dataset_subset/train/open ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/train/open/ | wc -l && echo \"\"=== Dataset_subset/train/closed ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/train/closed/ | wc -l && echo \"\"=== Dataset_subset/val/open ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/val/open/ | wc -l && echo \"\"=== Dataset_subset/val/closed ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/val/closed/)"
5
+ ]
6
+ }
7
+ }
models/cnn/CNN_MODEL/.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ DATA/** filter=lfs diff=lfs merge=lfs -text
models/cnn/CNN_MODEL/.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Dataset/train/
2
+ Dataset/val/
3
+ Dataset/test/
4
+ .DS_Store
models/cnn/CNN_MODEL/README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Eye Open / Closed Classifier (YOLOv11-CLS)
2
+
3
+
4
+ Binary classifier: **open** vs **closed** eyes.
5
+ Used as a baseline for eye-tracking, drowsiness, or focus detection.
6
+
7
+ ---
8
+
9
+ ## Model team task
10
+
11
+ - **Train** the YOLOv11s-cls eye classifier in a **separate notebook** (data split, epochs, GPU, export `best.pt`).
12
+ - Provide **trained weights** (`best.pt`) for this repo’s evaluation and inference scripts.
13
+
14
+
15
+
16
+ ---
17
+
18
+ ## Repo contents
19
+
20
+ - **notebooks/eye_classifier_colab.ipynb** — Data download (Kaggle), clean, split, undersample, **evaluate** (needs `best.pt` from model team), export.
21
+ - **scripts/predict_image.py** — Run classifier on single images (needs `best.pt`).
22
+ - **scripts/webcam_live.py** — Live webcam open/closed (needs `best.pt` + optional `weights/face_landmarker.task`).
23
+ - **scripts/video_infer.py** — Run on video files.
24
+ - **scripts/focus_infer.py** — Focus/attention inference.
25
+ - **weights/** — Put `best.pt` here; `face_landmarker.task` is downloaded on first webcam run if missing.
26
+ - **docs/** — Extra docs (e.g. UNNECESSARY_FILES.md if present).
27
+
28
+ ---
29
+
30
+ ## Dataset
31
+
32
+ - **Source:** [Kaggle — open/closed eyes](https://www.kaggle.com/datasets/sehriyarmemmedli/open-closed-eyes-dataset)
33
+ - The Colab notebook downloads it via `kagglehub`; no local copy in repo.
34
+
35
+ ---
36
+
37
+ ## Weights
38
+
39
+ - Put **best.pt** from the model team in **weights/best.pt** (or `runs/classify/runs_cls/eye_open_closed_cpu/weights/best.pt`).
40
+ - For webcam: **face_landmarker.task** is downloaded into **weights/** on first run if missing.
41
+
42
+ ---
43
+
44
+ ## Local setup
45
+
46
+ ```bash
47
+ pip install ultralytics opencv-python mediapipe "numpy<2"
48
+ ```
49
+
50
+ Optional: use a venv. From repo root:
51
+ - `python scripts/predict_image.py <image.png>`
52
+ - `python scripts/webcam_live.py`
53
+ - `python scripts/video_infer.py` (expects 1.mp4 / 2.mp4 in repo root or set `VIDEOS` env)
54
+ - `python scripts/focus_infer.py`
55
+
56
+ ---
57
+
58
+ ## Project structure
59
+
60
+ ```
61
+ ├── notebooks/
62
+ │ └── eye_classifier_colab.ipynb # Data + eval (no training)
63
+ ├── scripts/
64
+ │ ├── predict_image.py
65
+ │ ├── webcam_live.py
66
+ │ ├── video_infer.py
67
+ │ └── focus_infer.py
68
+ ├── weights/ # best.pt, face_landmarker.task
69
+ ├── docs/ # extra docs
70
+ ├── README.md
71
+ └── venv/ # optional
72
+ ```
73
+
74
+ Training and weight generation: **model team, separate notebook.**
models/cnn/CNN_MODEL/notebooks/eye_classifier_colab.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/cnn/CNN_MODEL/scripts/focus_infer.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ import os
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from ultralytics import YOLO
9
+
10
+
11
+ def list_images(folder: Path):
12
+ exts = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
13
+ return sorted([p for p in folder.iterdir() if p.suffix.lower() in exts])
14
+
15
+
16
+ def find_weights(project_root: Path) -> Path | None:
17
+ candidates = [
18
+ project_root / "weights" / "best.pt",
19
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
20
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
21
+ project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
22
+ project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
23
+ ]
24
+ return next((p for p in candidates if p.is_file()), None)
25
+
26
+
27
+ def detect_eyelid_boundary(gray: np.ndarray) -> np.ndarray | None:
28
+ """
29
+ Returns an ellipse fit to the largest contour near the eye boundary.
30
+ Output format: (center(x,y), (axis1, axis2), angle) or None.
31
+ """
32
+ blur = cv2.GaussianBlur(gray, (5, 5), 0)
33
+ edges = cv2.Canny(blur, 40, 120)
34
+ edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
35
+ contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
36
+ if not contours:
37
+ return None
38
+ contours = sorted(contours, key=cv2.contourArea, reverse=True)
39
+ for c in contours:
40
+ if len(c) >= 5 and cv2.contourArea(c) > 50:
41
+ return cv2.fitEllipse(c)
42
+ return None
43
+
44
+
45
+ def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
46
+ """
47
+ More robust pupil detection:
48
+ - enhance contrast (CLAHE)
49
+ - find dark blobs
50
+ - score by circularity and proximity to center
51
+ """
52
+ h, w = gray.shape
53
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
54
+ eq = clahe.apply(gray)
55
+ blur = cv2.GaussianBlur(eq, (7, 7), 0)
56
+
57
+ # Focus on the central region to avoid eyelashes/edges
58
+ cx, cy = w // 2, h // 2
59
+ rx, ry = int(w * 0.3), int(h * 0.3)
60
+ x0, x1 = max(cx - rx, 0), min(cx + rx, w)
61
+ y0, y1 = max(cy - ry, 0), min(cy + ry, h)
62
+ roi = blur[y0:y1, x0:x1]
63
+
64
+ # Inverted threshold to capture dark pupil
65
+ _, thresh = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
66
+ thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
67
+ thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
68
+
69
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
70
+ if not contours:
71
+ return None
72
+
73
+ best = None
74
+ best_score = -1.0
75
+ for c in contours:
76
+ area = cv2.contourArea(c)
77
+ if area < 15:
78
+ continue
79
+ perimeter = cv2.arcLength(c, True)
80
+ if perimeter <= 0:
81
+ continue
82
+ circularity = 4 * np.pi * (area / (perimeter * perimeter))
83
+ if circularity < 0.3:
84
+ continue
85
+ m = cv2.moments(c)
86
+ if m["m00"] == 0:
87
+ continue
88
+ px = int(m["m10"] / m["m00"]) + x0
89
+ py = int(m["m01"] / m["m00"]) + y0
90
+
91
+ # Score by circularity and distance to center
92
+ dist = np.hypot(px - cx, py - cy) / max(w, h)
93
+ score = circularity - dist
94
+ if score > best_score:
95
+ best_score = score
96
+ best = (px, py)
97
+
98
+ return best
99
+
100
+
101
+ def is_focused(pupil_center: tuple[int, int], img_shape: tuple[int, int]) -> bool:
102
+ """
103
+ Decide focus based on pupil offset from image center.
104
+ """
105
+ h, w = img_shape
106
+ cx, cy = w // 2, h // 2
107
+ px, py = pupil_center
108
+ dx = abs(px - cx) / max(w, 1)
109
+ dy = abs(py - cy) / max(h, 1)
110
+ return (dx < 0.12) and (dy < 0.12)
111
+
112
+
113
+ def annotate(img_bgr: np.ndarray, ellipse, pupil_center, focused: bool, cls_label: str, conf: float):
114
+ out = img_bgr.copy()
115
+ if ellipse is not None:
116
+ cv2.ellipse(out, ellipse, (0, 255, 255), 2)
117
+ if pupil_center is not None:
118
+ cv2.circle(out, pupil_center, 4, (0, 0, 255), -1)
119
+ label = f"{cls_label} ({conf:.2f}) | focused={int(focused)}"
120
+ cv2.putText(out, label, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
121
+ return out
122
+
123
+
124
+ def main():
125
+ project_root = Path(__file__).resolve().parent.parent
126
+ data_dir = project_root / "Dataset"
127
+ alt_data_dir = project_root / "DATA"
128
+ out_dir = project_root / "runs_focus"
129
+ out_dir.mkdir(parents=True, exist_ok=True)
130
+
131
+ weights = find_weights(project_root)
132
+ if weights is None:
133
+ print("Weights not found. Train first.")
134
+ return
135
+
136
+ # Support both Dataset/test/{open,closed} and Dataset/{open,closed}
137
+ def resolve_test_dirs(root: Path):
138
+ test_open = root / "test" / "open"
139
+ test_closed = root / "test" / "closed"
140
+ if test_open.exists() and test_closed.exists():
141
+ return test_open, test_closed
142
+ test_open = root / "open"
143
+ test_closed = root / "closed"
144
+ if test_open.exists() and test_closed.exists():
145
+ return test_open, test_closed
146
+ alt_closed = root / "close"
147
+ if test_open.exists() and alt_closed.exists():
148
+ return test_open, alt_closed
149
+ return None, None
150
+
151
+ test_open, test_closed = resolve_test_dirs(data_dir)
152
+ if (test_open is None or test_closed is None) and alt_data_dir.exists():
153
+ test_open, test_closed = resolve_test_dirs(alt_data_dir)
154
+
155
+ if not test_open.exists() or not test_closed.exists():
156
+ print("Test folders missing. Expected:")
157
+ print(test_open)
158
+ print(test_closed)
159
+ return
160
+
161
+ test_files = list_images(test_open) + list_images(test_closed)
162
+ print("Total test images:", len(test_files))
163
+ max_images = int(os.getenv("MAX_IMAGES", "0"))
164
+ if max_images > 0:
165
+ test_files = test_files[:max_images]
166
+ print("Limiting to MAX_IMAGES:", max_images)
167
+
168
+ model = YOLO(str(weights))
169
+ results = model.predict(test_files, imgsz=224, device="cpu", verbose=False)
170
+
171
+ names = model.names
172
+ for r in results:
173
+ probs = r.probs
174
+ top_idx = int(probs.top1)
175
+ top_conf = float(probs.top1conf)
176
+ pred_label = names[top_idx]
177
+
178
+ img = cv2.imread(r.path)
179
+ if img is None:
180
+ continue
181
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
182
+
183
+ ellipse = detect_eyelid_boundary(gray)
184
+ pupil_center = detect_pupil_center(gray)
185
+ focused = False
186
+ if pred_label.lower() == "open" and pupil_center is not None:
187
+ focused = is_focused(pupil_center, gray.shape)
188
+
189
+ annotated = annotate(img, ellipse, pupil_center, focused, pred_label, top_conf)
190
+ out_path = out_dir / (Path(r.path).stem + "_annotated.jpg")
191
+ cv2.imwrite(str(out_path), annotated)
192
+
193
+ print(f"{Path(r.path).name}: pred={pred_label} conf={top_conf:.3f} focused={focused}")
194
+
195
+ print(f"\nAnnotated outputs saved to: {out_dir}")
196
+
197
+
198
+ if __name__ == "__main__":
199
+ main()
models/cnn/CNN_MODEL/scripts/predict_image.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Run the eye open/closed model on one or more images."""
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ from ultralytics import YOLO
6
+
7
+
8
def main():
    """CLI entry point: classify each image path given on the command line.

    Locates the trained checkpoint, then prints one `name: label (conf)`
    line per readable image argument. Exits early with a message when no
    weights are found or no image paths were supplied.
    """
    root = Path(__file__).resolve().parent.parent
    candidates = (
        root / "weights" / "best.pt",
        root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
        root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
    )
    ckpt = next((c for c in candidates if c.is_file()), None)
    if ckpt is None:
        print("Weights not found. Put best.pt in weights/ or runs/.../weights/ (from model team).")
        sys.exit(1)

    if len(sys.argv) < 2:
        print("Usage: python scripts/predict_image.py <image1> [image2 ...]")
        print("Example: python scripts/predict_image.py path/to/image.png")
        sys.exit(0)

    model = YOLO(str(ckpt))
    class_names = model.names

    for arg in sys.argv[1:]:
        img_path = Path(arg)
        if not img_path.is_file():
            print(img_path, "- file not found")
            continue
        try:
            preds = model.predict(str(img_path), imgsz=224, device="cpu", verbose=False)
        except Exception as e:
            # Keep going: one unreadable/corrupt image should not stop the batch.
            print(img_path, "- error:", e)
            continue
        if not preds:
            print(img_path, "- no result")
            continue
        first = preds[0]
        top_idx = int(first.probs.top1)
        conf = float(first.probs.top1conf)
        label = class_names[top_idx]
        print(f"{img_path.name}: {label} ({conf:.2%})")
46
+
47
+
48
+ if __name__ == "__main__":
49
+ main()
models/cnn/CNN_MODEL/scripts/video_infer.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from ultralytics import YOLO
9
+
10
+ try:
11
+ import mediapipe as mp
12
+ except Exception: # pragma: no cover
13
+ mp = None
14
+
15
+
16
def find_weights(project_root: Path) -> Path | None:
    """Return the first existing classifier checkpoint under *project_root*.

    Checks weights/best.pt first, then the training run directories
    (best.pt before last.pt). Returns None when nothing is found.
    """
    relative_candidates = (
        ("weights", "best.pt"),
        ("runs", "classify", "runs_cls", "eye_open_closed_cpu", "weights", "best.pt"),
        ("runs", "classify", "runs_cls", "eye_open_closed_cpu", "weights", "last.pt"),
        ("runs_cls", "eye_open_closed_cpu", "weights", "best.pt"),
        ("runs_cls", "eye_open_closed_cpu", "weights", "last.pt"),
    )
    for parts in relative_candidates:
        candidate = project_root.joinpath(*parts)
        if candidate.is_file():
            return candidate
    return None
25
+
26
+
27
def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
    """Locate the pupil centre in a grayscale eye crop via dark-blob detection.

    Pipeline: CLAHE contrast boost -> Gaussian blur -> Otsu inverse threshold
    inside a central ROI -> morphological cleanup -> pick the most circular,
    most central contour. Returns (x, y) in full-image pixel coordinates, or
    None when no plausible pupil contour is found.

    NOTE(review): assumes *gray* is a tight eye crop so the pupil falls
    within the central 60% window — verify against caller.
    """
    h, w = gray.shape
    # Equalize local contrast so the dark pupil separates from the iris.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    eq = clahe.apply(gray)
    blur = cv2.GaussianBlur(eq, (7, 7), 0)

    # Restrict search to a centred ROI (60% of width/height) to avoid
    # eyelashes and corners.
    cx, cy = w // 2, h // 2
    rx, ry = int(w * 0.3), int(h * 0.3)
    x0, x1 = max(cx - rx, 0), min(cx + rx, w)
    y0, y1 = max(cy - ry, 0), min(cy + ry, h)
    roi = blur[y0:y1, x0:x1]

    # Inverse Otsu: the pupil (darkest region) becomes foreground.
    _, thresh = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Open removes speckle noise; close fills small holes in the blob.
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)

    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    best = None
    best_score = -1.0
    for c in contours:
        area = cv2.contourArea(c)
        # Discard tiny blobs (noise).
        if area < 15:
            continue
        perimeter = cv2.arcLength(c, True)
        if perimeter <= 0:
            continue
        # 1.0 for a perfect circle; reject elongated shapes below 0.3.
        circularity = 4 * np.pi * (area / (perimeter * perimeter))
        if circularity < 0.3:
            continue
        m = cv2.moments(c)
        if m["m00"] == 0:
            continue
        # Centroid in full-image coordinates (offset by ROI origin).
        px = int(m["m10"] / m["m00"]) + x0
        py = int(m["m01"] / m["m00"]) + y0

        # Penalize distance from image centre; prefer round, central blobs.
        dist = np.hypot(px - cx, py - cy) / max(w, h)
        score = circularity - dist
        if score > best_score:
            best_score = score
            best = (px, py)

    return best
72
+
73
+
74
def is_focused(pupil_center: tuple[int, int], img_shape: tuple[int, int]) -> bool:
    """True when the pupil lies within 12% of the image width of the horizontal centre."""
    _, width = img_shape
    pupil_x = pupil_center[0]
    offset = abs(pupil_x - width // 2)
    return offset / max(width, 1) < 0.12
80
+
81
+
82
def classify_frame(model: YOLO, frame: np.ndarray) -> tuple[str, float]:
    """Classify *frame* with the open/closed model.

    Assumes the frame is already an eye crop. Returns the top-1
    (label, confidence) pair.
    """
    result = model.predict(frame, imgsz=224, device="cpu", verbose=False)[0]
    top_idx = int(result.probs.top1)
    return model.names[top_idx], float(result.probs.top1conf)
91
+
92
+
93
def annotate_frame(frame: np.ndarray, label: str, focused: bool, conf: float, time_sec: float):
    """Return a copy of *frame* with a one-line status banner drawn at the top left."""
    canvas = frame.copy()
    banner = f"{label} | focused={int(focused)} | conf={conf:.2f} | t={time_sec:.2f}s"
    cv2.putText(canvas, banner, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    return canvas
98
+
99
+
100
def write_segments(path: Path, segments: list[tuple[float, float, str]]):
    """Write (start, end, label) segments to *path*, one `s,e,label` line each (2 d.p.)."""
    lines = [f"{start:.2f},{end:.2f},{label}\n" for start, end, label in segments]
    with path.open("w") as out:
        out.writelines(lines)
104
+
105
+
106
def process_video(video_path: Path, model: YOLO | None):
    """Run per-frame attention inference on *video_path* and write three outputs.

    Outputs (next to the input file):
      * ``<stem>_pred.mp4``       — annotated copy of the video,
      * ``<stem>_predictions.csv``— per-frame time/label/focused/conf rows,
      * ``<stem>_segments.txt``   — contiguous same-label time segments.

    When mediapipe is available, eyes are analysed geometrically (EAR for
    open/closed, iris offset for focus); otherwise the YOLO classifier plus
    a pupil heuristic is used. Prints paths on completion.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"Failed to open {video_path}")
        return

    # Fall back to 30 fps when the container reports 0/None.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = video_path.with_name(video_path.stem + "_pred.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(out_path), fourcc, fps, (width, height))

    csv_path = video_path.with_name(video_path.stem + "_predictions.csv")
    seg_path = video_path.with_name(video_path.stem + "_segments.txt")

    frame_idx = 0
    last_label = None          # label of the segment currently being built
    seg_start = 0.0            # start time (s) of that segment
    segments: list[tuple[float, float, str]] = []

    with csv_path.open("w") as fcsv:
        fcsv.write("time_sec,label,focused,conf\n")
        if mp is None:
            print("mediapipe is not installed. Falling back to classifier-only mode.")
        use_mp = mp is not None
        if use_mp:
            mp_face_mesh = mp.solutions.face_mesh
            face_mesh = mp_face_mesh.FaceMesh(
                static_image_mode=False,
                max_num_faces=1,
                refine_landmarks=True,   # needed for iris landmarks 468-476
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5,
            )

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            time_sec = frame_idx / fps
            # Per-frame defaults; conf stays 0.0 on the mediapipe path.
            conf = 0.0
            pred_label = "open"
            focused = False

            if use_mp:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                res = face_mesh.process(rgb)
                if res.multi_face_landmarks:
                    lm = res.multi_face_landmarks[0].landmark
                    h, w = frame.shape[:2]

                    # Eye landmarks (MediaPipe FaceMesh)
                    left_eye = [33, 160, 158, 133, 153, 144]
                    right_eye = [362, 385, 387, 263, 373, 380]
                    left_iris = [468, 469, 470, 471]
                    right_iris = [473, 474, 475, 476]

                    def pts(idxs):
                        # Normalized landmarks -> integer pixel coordinates.
                        return np.array([(int(lm[i].x * w), int(lm[i].y * h)) for i in idxs])

                    def ear(eye_pts):
                        # EAR using 6 points; +1e-6 guards a degenerate eye width.
                        p1, p2, p3, p4, p5, p6 = eye_pts
                        v1 = np.linalg.norm(p2 - p6)
                        v2 = np.linalg.norm(p3 - p5)
                        h1 = np.linalg.norm(p1 - p4)
                        return (v1 + v2) / (2.0 * h1 + 1e-6)

                    le = pts(left_eye)
                    re = pts(right_eye)
                    le_ear = ear(le)
                    re_ear = ear(re)
                    ear_avg = (le_ear + re_ear) / 2.0

                    # openness threshold (0.22 — empirical; TODO confirm on data)
                    pred_label = "open" if ear_avg > 0.22 else "closed"

                    # iris centers (mean of the 4 iris landmarks)
                    li = pts(left_iris)
                    ri = pts(right_iris)
                    li_c = li.mean(axis=0).astype(int)
                    ri_c = ri.mean(axis=0).astype(int)

                    # eye centers (midpoint of corners)
                    le_c = ((le[0] + le[3]) / 2).astype(int)
                    re_c = ((re[0] + re[3]) / 2).astype(int)

                    # focus = iris close to eye center horizontally for both eyes
                    le_dx = abs(li_c[0] - le_c[0]) / max(np.linalg.norm(le[0] - le[3]), 1)
                    re_dx = abs(ri_c[0] - re_c[0]) / max(np.linalg.norm(re[0] - re[3]), 1)
                    focused = (pred_label == "open") and (le_dx < 0.18) and (re_dx < 0.18)

                    # draw eye boundaries
                    cv2.polylines(frame, [le], True, (0, 255, 255), 1)
                    cv2.polylines(frame, [re], True, (0, 255, 255), 1)
                    # draw iris centers
                    cv2.circle(frame, tuple(li_c), 3, (0, 0, 255), -1)
                    cv2.circle(frame, tuple(ri_c), 3, (0, 0, 255), -1)
                else:
                    # No face detected: treat the frame as closed / unfocused.
                    pred_label = "closed"
                    focused = False
            else:
                # Classifier-only fallback: YOLO label + pupil-position heuristic.
                if model is not None:
                    pred_label, conf = classify_frame(model, frame)
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                pupil_center = detect_pupil_center(gray) if pred_label.lower() == "open" else None
                focused = False
                if pred_label.lower() == "open" and pupil_center is not None:
                    focused = is_focused(pupil_center, gray.shape)

            # A closed eye can never be focused.
            if pred_label.lower() != "open":
                focused = False

            label = "open_focused" if (pred_label.lower() == "open" and focused) else "open_not_focused"
            if pred_label.lower() != "open":
                label = "closed_not_focused"

            fcsv.write(f"{time_sec:.2f},{label},{int(focused)},{conf:.4f}\n")

            # Segment bookkeeping: close the previous run when the label changes.
            if last_label is None:
                last_label = label
                seg_start = time_sec
            elif label != last_label:
                segments.append((seg_start, time_sec, last_label))
                seg_start = time_sec
                last_label = label

            annotated = annotate_frame(frame, label, focused, conf, time_sec)
            writer.write(annotated)
            frame_idx += 1

    # Flush the trailing segment (end = total duration).
    if last_label is not None:
        end_time = frame_idx / fps
        segments.append((seg_start, end_time, last_label))
    write_segments(seg_path, segments)

    cap.release()
    writer.release()
    print(f"Saved: {out_path}")
    print(f"CSV: {csv_path}")
    print(f"Segments: {seg_path}")
249
+
250
+
251
def main():
    """Process 1.mp4 / 2.mp4 from the project root plus any paths in $VIDEOS."""
    project_root = Path(__file__).resolve().parent.parent
    ckpt = find_weights(project_root)
    model = YOLO(str(ckpt)) if ckpt is not None else None

    # Default to 1.mp4 and 2.mp4 in project root.
    videos = [project_root / name for name in ("1.mp4", "2.mp4") if (project_root / name).exists()]

    # Extra inputs: comma-separated VIDEOS env var; relative paths resolve
    # against the project root.
    for entry in os.getenv("VIDEOS", "").split(","):
        entry = entry.strip()
        if not entry:
            continue
        candidate = Path(entry)
        if not candidate.is_absolute():
            candidate = project_root / candidate
        if candidate.exists():
            videos.append(candidate)

    if not videos:
        print("No videos found. Expected 1.mp4 / 2.mp4 in project root.")
        return

    for video in videos:
        process_video(video, model)
278
+
279
+
280
+ if __name__ == "__main__":
281
+ main()
models/cnn/CNN_MODEL/scripts/webcam_live.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Live webcam: detect face, crop each eye, run open/closed classifier, show on screen.
3
+ Requires: opencv-python, ultralytics, mediapipe (pip install mediapipe).
4
+ Press 'q' to quit.
5
+ """
6
+ import urllib.request
7
+ from pathlib import Path
8
+
9
+ import cv2
10
+ import numpy as np
11
+ from ultralytics import YOLO
12
+
13
+ try:
14
+ import mediapipe as mp
15
+ _mp_has_solutions = hasattr(mp, "solutions")
16
+ except ImportError:
17
+ mp = None
18
+ _mp_has_solutions = False
19
+
20
+ # New MediaPipe Tasks API (Face Landmarker) eye indices
21
+ LEFT_EYE_INDICES_NEW = [263, 249, 390, 373, 374, 380, 381, 382, 362, 466, 388, 387, 386, 385, 384, 398]
22
+ RIGHT_EYE_INDICES_NEW = [33, 7, 163, 144, 145, 153, 154, 155, 133, 246, 161, 160, 159, 158, 157, 173]
23
+ # Old Face Mesh (solutions) indices
24
+ LEFT_EYE_INDICES_OLD = [33, 160, 158, 133, 153, 144]
25
+ RIGHT_EYE_INDICES_OLD = [362, 385, 387, 263, 373, 380]
26
+ EYE_PADDING = 0.35
27
+
28
+
29
def find_weights(project_root: Path) -> Path | None:
    """Locate the trained eye classifier checkpoint, preferring weights/best.pt."""
    search_order = (
        ("weights", "best.pt"),
        ("runs", "classify", "runs_cls", "eye_open_closed_cpu", "weights", "best.pt"),
        ("runs", "classify", "runs_cls", "eye_open_closed_cpu", "weights", "last.pt"),
    )
    for parts in search_order:
        candidate = project_root.joinpath(*parts)
        if candidate.is_file():
            return candidate
    return None
36
+
37
+
38
def get_eye_roi(frame: np.ndarray, landmarks, indices: list[int]) -> np.ndarray | None:
    """Crop the eye region spanned by *indices* from *frame*, padded by EYE_PADDING.

    *landmarks* holds normalized (x, y) coordinates; the padding is at least
    8 px per side and the box is clamped to the frame. Returns a copy of the
    crop, or None for a degenerate (empty) box.
    """
    h, w = frame.shape[:2]
    coords = np.array([(int(landmarks[i].x * w), int(landmarks[i].y * h)) for i in indices])
    left, top = coords.min(axis=0)
    right, bottom = coords.max(axis=0)

    pad_x = max(int((right - left) * EYE_PADDING), 8)
    pad_y = max(int((bottom - top) * EYE_PADDING), 8)

    x0 = max(0, left - pad_x)
    y0 = max(0, top - pad_y)
    x1 = min(w, right + pad_x)
    y1 = min(h, bottom + pad_y)
    if x1 <= x0 or y1 <= y0:
        return None
    return frame[y0:y1, x0:x1].copy()
52
+
53
+
54
def _run_with_solutions(mp, model, cap):
    """Live loop using the legacy mediapipe `solutions` FaceMesh API.

    Per frame: detect face landmarks, crop both eyes, classify each with the
    YOLO *model*, and overlay the labels. Exits when the stream ends or 'q'
    is pressed. Display is via cv2.imshow; nothing is returned.
    """
    face_mesh = mp.solutions.face_mesh.FaceMesh(
        static_image_mode=False,
        max_num_faces=1,
        refine_landmarks=True,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # FaceMesh expects RGB input.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = face_mesh.process(rgb)
        # Em-dash placeholder shown when an eye could not be classified.
        left_label, left_conf = "—", 0.0
        right_label, right_conf = "—", 0.0
        if results.multi_face_landmarks:
            lm = results.multi_face_landmarks[0].landmark
            for roi, indices, side in [
                (get_eye_roi(frame, lm, LEFT_EYE_INDICES_OLD), LEFT_EYE_INDICES_OLD, "left"),
                (get_eye_roi(frame, lm, RIGHT_EYE_INDICES_OLD), RIGHT_EYE_INDICES_OLD, "right"),
            ]:
                if roi is not None and roi.size > 0:
                    try:
                        pred = model.predict(roi, imgsz=224, device="cpu", verbose=False)
                        if pred:
                            r = pred[0]
                            label = model.names[int(r.probs.top1)]
                            conf = float(r.probs.top1conf)
                            if side == "left":
                                left_label, left_conf = label, conf
                            else:
                                right_label, right_conf = label, conf
                    except Exception:
                        # Best-effort per-eye: a failed prediction keeps the placeholder.
                        pass
        cv2.putText(frame, f"L: {left_label} ({left_conf:.0%})", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.putText(frame, f"R: {right_label} ({right_conf:.0%})", (20, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.imshow("Eye open/closed (q to quit)", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
94
+
95
+
96
def _run_with_tasks(project_root: Path, model, cap):
    """Live loop using the newer MediaPipe Tasks FaceLandmarker API.

    Downloads the landmarker model on first use (into weights/), then per
    frame: detect landmarks, crop both eyes, classify with *model*, overlay
    labels. Exits when the stream ends or 'q' is pressed.
    """
    from mediapipe.tasks.python import BaseOptions
    from mediapipe.tasks.python.vision import FaceLandmarker, FaceLandmarkerOptions
    from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode
    from mediapipe.tasks.python.vision.core import image as image_lib

    # One-time download of the Google-hosted landmarker asset.
    model_path = project_root / "weights" / "face_landmarker.task"
    if not model_path.is_file():
        print("Downloading face_landmarker.task ...")
        url = "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task"
        urllib.request.urlretrieve(url, model_path)
        print("Done.")

    options = FaceLandmarkerOptions(
        base_options=BaseOptions(model_asset_path=str(model_path)),
        running_mode=running_mode.VisionTaskRunningMode.IMAGE,
        num_faces=1,
    )
    face_landmarker = FaceLandmarker.create_from_options(options)
    ImageFormat = image_lib.ImageFormat

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Em-dash placeholder shown when an eye could not be classified.
        left_label, left_conf = "—", 0.0
        right_label, right_conf = "—", 0.0

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # mp.Image requires a contiguous buffer.
        rgb_contiguous = np.ascontiguousarray(rgb)
        mp_image = image_lib.Image(ImageFormat.SRGB, rgb_contiguous)
        result = face_landmarker.detect(mp_image)

        if result.face_landmarks:
            lm = result.face_landmarks[0]
            for roi, side in [
                (get_eye_roi(frame, lm, LEFT_EYE_INDICES_NEW), "left"),
                (get_eye_roi(frame, lm, RIGHT_EYE_INDICES_NEW), "right"),
            ]:
                if roi is not None and roi.size > 0:
                    try:
                        pred = model.predict(roi, imgsz=224, device="cpu", verbose=False)
                        if pred:
                            r = pred[0]
                            label = model.names[int(r.probs.top1)]
                            conf = float(r.probs.top1conf)
                            if side == "left":
                                left_label, left_conf = label, conf
                            else:
                                right_label, right_conf = label, conf
                    except Exception:
                        # Best-effort per-eye: a failed prediction keeps the placeholder.
                        pass

        cv2.putText(frame, f"L: {left_label} ({left_conf:.0%})", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.putText(frame, f"R: {right_label} ({right_conf:.0%})", (20, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.imshow("Eye open/closed (q to quit)", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
154
+
155
+
156
def main():
    """Entry point: open the webcam and run the live open/closed overlay."""
    project_root = Path(__file__).resolve().parent.parent
    ckpt = find_weights(project_root)
    if ckpt is None:
        print("Weights not found. Put best.pt in weights/ or runs/.../weights/ (from model team).")
        return
    if mp is None:
        print("MediaPipe required. Install: pip install mediapipe")
        return

    model = YOLO(str(ckpt))
    capture = cv2.VideoCapture(0)
    if not capture.isOpened():
        print("Could not open webcam.")
        return

    print("Live eye open/closed on your face. Press 'q' to quit.")
    try:
        # Prefer the legacy solutions API when this mediapipe build exposes it.
        runner = _run_with_solutions if _mp_has_solutions else None
        if runner is not None:
            runner(mp, model, capture)
        else:
            _run_with_tasks(project_root, model, capture)
    finally:
        capture.release()
        cv2.destroyAllWindows()
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()
models/cnn/CNN_MODEL/weights/yolo11s-cls.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2b605d1c8c212b434a75a32759a6f7adf1d2b29c35f76bdccd4c794cb653cf2
3
+ size 13630112
models/cnn/__init__.py ADDED
File without changes
models/cnn/eye_attention/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
models/cnn/eye_attention/classifier.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ import numpy as np
6
+
7
+
8
class EyeClassifier(ABC):
    """Abstract interface for eye-crop attentiveness backends."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier for this backend."""
        ...

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return an attentiveness score for a batch of BGR eye crops."""
        ...
17
+
18
+
19
class GeometricOnlyClassifier(EyeClassifier):
    """Fallback backend with no learned model: always reports a neutral score."""

    @property
    def name(self) -> str:
        return "geometric"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        # Constant 1.0 so downstream fusion relies on geometric cues alone.
        return 1.0
26
+
27
+
28
class YOLOv11Classifier(EyeClassifier):
    """Eye classifier backed by an Ultralytics YOLO classification checkpoint."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device

        names = self._model.names
        # Prefer the class literally named "open"/"attentive"; otherwise fall
        # back to the highest class index.
        self._attentive_idx = next(
            (idx for idx, cls_name in names.items() if cls_name in ("open", "attentive")),
            None,
        )
        if self._attentive_idx is None:
            self._attentive_idx = max(names.keys())
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    @property
    def name(self) -> str:
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Mean probability of the attentive class over the given crops (1.0 if none)."""
        if not crops_bgr:
            return 1.0
        outputs = self._model.predict(crops_bgr, device=self._device, verbose=False)
        per_crop = [float(out.probs.data[self._attentive_idx]) for out in outputs]
        return sum(per_crop) / len(per_crop) if per_crop else 1.0
55
+
56
+
57
def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Build an eye classifier backend.

    Falls back to the geometric backend when no checkpoint path is given or
    when backend == "geometric"; otherwise loads the YOLO backend and
    re-raises ImportError if ultralytics is missing.
    """
    use_geometric = path is None or backend == "geometric"
    if use_geometric:
        return GeometricOnlyClassifier()

    try:
        return YOLOv11Classifier(path, device=device)
    except ImportError:
        print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
        raise
models/cnn/eye_attention/crop.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from models.pretrained.face_mesh.face_mesh import FaceMeshDetector
5
+
6
+ LEFT_EYE_CONTOUR = FaceMeshDetector.LEFT_EYE_INDICES
7
+ RIGHT_EYE_CONTOUR = FaceMeshDetector.RIGHT_EYE_INDICES
8
+
9
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
10
+ IMAGENET_STD = (0.229, 0.224, 0.225)
11
+
12
+ CROP_SIZE = 96
13
+
14
+
15
+ def _bbox_from_landmarks(
16
+ landmarks: np.ndarray,
17
+ indices: list[int],
18
+ frame_w: int,
19
+ frame_h: int,
20
+ expand: float = 0.4,
21
+ ) -> tuple[int, int, int, int]:
22
+ pts = landmarks[indices, :2]
23
+ px = pts[:, 0] * frame_w
24
+ py = pts[:, 1] * frame_h
25
+
26
+ x_min, x_max = px.min(), px.max()
27
+ y_min, y_max = py.min(), py.max()
28
+ w = x_max - x_min
29
+ h = y_max - y_min
30
+ cx = (x_min + x_max) / 2
31
+ cy = (y_min + y_max) / 2
32
+
33
+ size = max(w, h) * (1 + expand)
34
+ half = size / 2
35
+
36
+ x1 = int(max(cx - half, 0))
37
+ y1 = int(max(cy - half, 0))
38
+ x2 = int(min(cx + half, frame_w))
39
+ y2 = int(min(cy + half, frame_h))
40
+
41
+ return x1, y1, x2, y2
42
+
43
+
44
def extract_eye_crops(
    frame: np.ndarray,
    landmarks: np.ndarray,
    expand: float = 0.4,
    crop_size: int = CROP_SIZE,
) -> tuple[np.ndarray, np.ndarray, tuple, tuple]:
    """Extract resized square crops for both eyes.

    Returns (left_crop, right_crop, left_bbox, right_bbox); crops are
    resized to crop_size x crop_size with INTER_AREA.
    """
    frame_h, frame_w = frame.shape[:2]

    crops = []
    boxes = []
    for contour in (LEFT_EYE_CONTOUR, RIGHT_EYE_CONTOUR):
        bbox = _bbox_from_landmarks(landmarks, contour, frame_w, frame_h, expand)
        patch = frame[bbox[1] : bbox[3], bbox[0] : bbox[2]]
        crops.append(cv2.resize(patch, (crop_size, crop_size), interpolation=cv2.INTER_AREA))
        boxes.append(bbox)

    return crops[0], crops[1], boxes[0], boxes[1]
+ return left_crop, right_crop, left_bbox, right_bbox
62
+
63
+
64
def crop_to_tensor(crop_bgr: np.ndarray):
    """Convert a BGR eye crop to an ImageNet-normalized CHW float32 torch tensor."""
    # Local import keeps torch optional for callers that never build tensors.
    import torch

    # BGR -> RGB, scale to [0, 1].
    rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Per-channel ImageNet normalization; done channel-by-channel so the
    # array stays float32.
    for c in range(3):
        rgb[:, :, c] = (rgb[:, :, c] - IMAGENET_MEAN[c]) / IMAGENET_STD[c]
    # HWC -> CHW layout expected by torch models.
    return torch.from_numpy(rgb.transpose(2, 0, 1))
+ return torch.from_numpy(rgb.transpose(2, 0, 1))
models/cnn/eye_attention/train.py ADDED
File without changes
models/geometric/__init__.py ADDED
File without changes
models/geometric/eye_behaviour/__init__.py ADDED
File without changes
models/geometric/eye_behaviour/eye_scorer.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+
5
+ _LEFT_EYE_EAR = [33, 160, 158, 133, 153, 145]
6
+ _RIGHT_EYE_EAR = [362, 385, 387, 263, 373, 380]
7
+
8
+ _LEFT_IRIS_CENTER = 468
9
+ _RIGHT_IRIS_CENTER = 473
10
+
11
+ _LEFT_EYE_INNER = 133
12
+ _LEFT_EYE_OUTER = 33
13
+ _RIGHT_EYE_INNER = 362
14
+ _RIGHT_EYE_OUTER = 263
15
+
16
+ _LEFT_EYE_TOP = 159
17
+ _LEFT_EYE_BOTTOM = 145
18
+ _RIGHT_EYE_TOP = 386
19
+ _RIGHT_EYE_BOTTOM = 374
20
+
21
+ _MOUTH_TOP = 13
22
+ _MOUTH_BOTTOM = 14
23
+ _MOUTH_LEFT = 78
24
+ _MOUTH_RIGHT = 308
25
+ _MOUTH_UPPER_1 = 82
26
+ _MOUTH_UPPER_2 = 312
27
+ _MOUTH_LOWER_1 = 87
28
+ _MOUTH_LOWER_2 = 317
29
+
30
+ MAR_YAWN_THRESHOLD = 0.55
31
+
32
+
33
+ def _distance(p1: np.ndarray, p2: np.ndarray) -> float:
34
+ return float(np.linalg.norm(p1 - p2))
35
+
36
+
37
+ def compute_ear(landmarks: np.ndarray, eye_indices: list[int]) -> float:
38
+ p1 = landmarks[eye_indices[0], :2]
39
+ p2 = landmarks[eye_indices[1], :2]
40
+ p3 = landmarks[eye_indices[2], :2]
41
+ p4 = landmarks[eye_indices[3], :2]
42
+ p5 = landmarks[eye_indices[4], :2]
43
+ p6 = landmarks[eye_indices[5], :2]
44
+
45
+ vertical1 = _distance(p2, p6)
46
+ vertical2 = _distance(p3, p5)
47
+ horizontal = _distance(p1, p4)
48
+
49
+ if horizontal < 1e-6:
50
+ return 0.0
51
+
52
+ return (vertical1 + vertical2) / (2.0 * horizontal)
53
+
54
+
55
+ def compute_avg_ear(landmarks: np.ndarray) -> float:
56
+ left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
57
+ right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
58
+ return (left_ear + right_ear) / 2.0
59
+
60
+
61
+ def compute_gaze_ratio(landmarks: np.ndarray) -> tuple[float, float]:
62
+ left_iris = landmarks[_LEFT_IRIS_CENTER, :2]
63
+ left_inner = landmarks[_LEFT_EYE_INNER, :2]
64
+ left_outer = landmarks[_LEFT_EYE_OUTER, :2]
65
+ left_top = landmarks[_LEFT_EYE_TOP, :2]
66
+ left_bottom = landmarks[_LEFT_EYE_BOTTOM, :2]
67
+
68
+ right_iris = landmarks[_RIGHT_IRIS_CENTER, :2]
69
+ right_inner = landmarks[_RIGHT_EYE_INNER, :2]
70
+ right_outer = landmarks[_RIGHT_EYE_OUTER, :2]
71
+ right_top = landmarks[_RIGHT_EYE_TOP, :2]
72
+ right_bottom = landmarks[_RIGHT_EYE_BOTTOM, :2]
73
+
74
+ left_h_total = _distance(left_inner, left_outer)
75
+ right_h_total = _distance(right_inner, right_outer)
76
+
77
+ if left_h_total < 1e-6 or right_h_total < 1e-6:
78
+ return 0.5, 0.5
79
+
80
+ left_h_ratio = _distance(left_outer, left_iris) / left_h_total
81
+ right_h_ratio = _distance(right_outer, right_iris) / right_h_total
82
+ h_ratio = (left_h_ratio + right_h_ratio) / 2.0
83
+
84
+ left_v_total = _distance(left_top, left_bottom)
85
+ right_v_total = _distance(right_top, right_bottom)
86
+
87
+ if left_v_total < 1e-6 or right_v_total < 1e-6:
88
+ return h_ratio, 0.5
89
+
90
+ left_v_ratio = _distance(left_top, left_iris) / left_v_total
91
+ right_v_ratio = _distance(right_top, right_iris) / right_v_total
92
+ v_ratio = (left_v_ratio + right_v_ratio) / 2.0
93
+
94
+ return float(np.clip(h_ratio, 0, 1)), float(np.clip(v_ratio, 0, 1))
95
+
96
+
97
+ def compute_mar(landmarks: np.ndarray) -> float:
98
+ # Mouth aspect ratio: high = mouth open (yawning / sleepy)
99
+ top = landmarks[_MOUTH_TOP, :2]
100
+ bottom = landmarks[_MOUTH_BOTTOM, :2]
101
+ left = landmarks[_MOUTH_LEFT, :2]
102
+ right = landmarks[_MOUTH_RIGHT, :2]
103
+ upper1 = landmarks[_MOUTH_UPPER_1, :2]
104
+ lower1 = landmarks[_MOUTH_LOWER_1, :2]
105
+ upper2 = landmarks[_MOUTH_UPPER_2, :2]
106
+ lower2 = landmarks[_MOUTH_LOWER_2, :2]
107
+
108
+ horizontal = _distance(left, right)
109
+ if horizontal < 1e-6:
110
+ return 0.0
111
+ v1 = _distance(upper1, lower1)
112
+ v2 = _distance(top, bottom)
113
+ v3 = _distance(upper2, lower2)
114
+ return (v1 + v2 + v3) / (2.0 * horizontal)
115
+
116
+
117
+ class EyeBehaviourScorer:
118
+ def __init__(
119
+ self,
120
+ ear_open: float = 0.30,
121
+ ear_closed: float = 0.16,
122
+ gaze_max_offset: float = 0.28,
123
+ ):
124
+ self.ear_open = ear_open
125
+ self.ear_closed = ear_closed
126
+ self.gaze_max_offset = gaze_max_offset
127
+
128
+ def _ear_score(self, ear: float) -> float:
129
+ if ear >= self.ear_open:
130
+ return 1.0
131
+ if ear <= self.ear_closed:
132
+ return 0.0
133
+ return (ear - self.ear_closed) / (self.ear_open - self.ear_closed)
134
+
135
+ def _gaze_score(self, h_ratio: float, v_ratio: float) -> float:
136
+ h_offset = abs(h_ratio - 0.5)
137
+ v_offset = abs(v_ratio - 0.5)
138
+ offset = math.sqrt(h_offset**2 + v_offset**2)
139
+ t = min(offset / self.gaze_max_offset, 1.0)
140
+ return 0.5 * (1.0 + math.cos(math.pi * t))
141
+
142
+ def score(self, landmarks: np.ndarray) -> float:
143
+ ear = compute_avg_ear(landmarks)
144
+ ear_s = self._ear_score(ear)
145
+ if ear_s < 0.3:
146
+ return ear_s
147
+ h_ratio, v_ratio = compute_gaze_ratio(landmarks)
148
+ gaze_s = self._gaze_score(h_ratio, v_ratio)
149
+ return ear_s * gaze_s
150
+
151
+ def detailed_score(self, landmarks: np.ndarray) -> dict:
152
+ ear = compute_avg_ear(landmarks)
153
+ ear_s = self._ear_score(ear)
154
+ h_ratio, v_ratio = compute_gaze_ratio(landmarks)
155
+ gaze_s = self._gaze_score(h_ratio, v_ratio)
156
+ s_eye = ear_s if ear_s < 0.3 else ear_s * gaze_s
157
+ return {
158
+ "ear": round(ear, 4),
159
+ "ear_score": round(ear_s, 4),
160
+ "h_gaze": round(h_ratio, 4),
161
+ "v_gaze": round(v_ratio, 4),
162
+ "gaze_score": round(gaze_s, 4),
163
+ "s_eye": round(s_eye, 4),
164
+ }
models/geometric/face_orientation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
models/geometric/face_orientation/head_pose.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ _LANDMARK_INDICES = [1, 152, 33, 263, 61, 291]
7
+
8
+ _MODEL_POINTS = np.array(
9
+ [
10
+ [0.0, 0.0, 0.0],
11
+ [0.0, -330.0, -65.0],
12
+ [-225.0, 170.0, -135.0],
13
+ [225.0, 170.0, -135.0],
14
+ [-150.0, -150.0, -125.0],
15
+ [150.0, -150.0, -125.0],
16
+ ],
17
+ dtype=np.float64,
18
+ )
19
+
20
+
21
class HeadPoseEstimator:
    """Estimate head pose (yaw/pitch/roll) from 2-D face landmarks via PnP.

    Uses six canonical 3-D model points (nose, chin, eye corners, mouth
    corners) matched against the corresponding MediaPipe landmark indices,
    and scores how frontal the head is with a cosine falloff.
    """

    def __init__(self, max_angle: float = 30.0, roll_weight: float = 0.5):
        # Combined angular deviation (degrees) at which the score reaches 0.
        self.max_angle = max_angle
        # Roll contributes less than yaw/pitch to the deviation metric.
        self.roll_weight = roll_weight
        self._camera_matrix = None
        self._frame_size = None
        # No lens distortion assumed.
        self._dist_coeffs = np.zeros((4, 1), dtype=np.float64)

    def _get_camera_matrix(self, frame_w: int, frame_h: int) -> np.ndarray:
        """Build (and cache per frame size) a pinhole intrinsics matrix.

        Approximation: focal length = frame width, principal point = centre.
        """
        if self._camera_matrix is not None and self._frame_size == (frame_w, frame_h):
            return self._camera_matrix
        focal_length = float(frame_w)
        cx, cy = frame_w / 2.0, frame_h / 2.0
        self._camera_matrix = np.array(
            [[focal_length, 0, cx], [0, focal_length, cy], [0, 0, 1]],
            dtype=np.float64,
        )
        self._frame_size = (frame_w, frame_h)
        return self._camera_matrix

    def _solve(self, landmarks: np.ndarray, frame_w: int, frame_h: int):
        """Run solvePnP; returns (success, rvec, tvec, image_points).

        *landmarks* are normalized coordinates, scaled here to pixels.
        """
        image_points = np.array(
            [
                [landmarks[i, 0] * frame_w, landmarks[i, 1] * frame_h]
                for i in _LANDMARK_INDICES
            ],
            dtype=np.float64,
        )
        camera_matrix = self._get_camera_matrix(frame_w, frame_h)
        success, rvec, tvec = cv2.solvePnP(
            _MODEL_POINTS,
            image_points,
            camera_matrix,
            self._dist_coeffs,
            flags=cv2.SOLVEPNP_ITERATIVE,
        )
        return success, rvec, tvec, image_points

    def estimate(
        self, landmarks: np.ndarray, frame_w: int, frame_h: int
    ) -> tuple[float, float, float] | None:
        """Return (yaw, pitch, roll) in degrees, or None when PnP fails."""
        success, rvec, tvec, _ = self._solve(landmarks, frame_w, frame_h)
        if not success:
            return None

        # Convert rotation vector to a matrix, then derive angles from the
        # rotated forward (nose) and up axes of the head.
        rmat, _ = cv2.Rodrigues(rvec)
        nose_dir = rmat @ np.array([0.0, 0.0, 1.0])
        face_up = rmat @ np.array([0.0, 1.0, 0.0])

        yaw = math.degrees(math.atan2(nose_dir[0], -nose_dir[2]))
        pitch = math.degrees(math.asin(np.clip(-nose_dir[1], -1.0, 1.0)))
        roll = math.degrees(math.atan2(face_up[0], -face_up[1]))

        return (yaw, pitch, roll)

    def score(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> float:
        """Frontality score in [0, 1]: 1 facing the camera, 0 at/after max_angle."""
        angles = self.estimate(landmarks, frame_w, frame_h)
        if angles is None:
            return 0.0

        yaw, pitch, roll = angles
        # Weighted angular deviation; roll is down-weighted by roll_weight.
        deviation = math.sqrt(yaw**2 + pitch**2 + (self.roll_weight * roll) ** 2)
        t = min(deviation / self.max_angle, 1.0)
        # Smooth cosine falloff from 1 to 0.
        return 0.5 * (1.0 + math.cos(math.pi * t))

    def draw_axes(
        self,
        frame: np.ndarray,
        landmarks: np.ndarray,
        axis_length: float = 50.0,
    ) -> np.ndarray:
        """Draw XYZ pose axes anchored at the nose onto *frame* (in place).

        Returns the frame unchanged when PnP fails.
        """
        h, w = frame.shape[:2]
        success, rvec, tvec, image_points = self._solve(landmarks, w, h)
        if not success:
            return frame

        camera_matrix = self._get_camera_matrix(w, h)
        # image_points[0] corresponds to the nose-tip landmark.
        nose = tuple(image_points[0].astype(int))

        axes_3d = np.float64(
            [[axis_length, 0, 0], [0, axis_length, 0], [0, 0, axis_length]]
        )
        projected, _ = cv2.projectPoints(
            axes_3d, rvec, tvec, camera_matrix, self._dist_coeffs
        )

        # X axis red, Y green, Z blue (BGR colours).
        colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]
        for i, color in enumerate(colors):
            pt = tuple(projected[i].ravel().astype(int))
            cv2.line(frame, nose, pt, color, 2)

        return frame
models/mlp/__init__.py ADDED
File without changes
models/{train.py → mlp/train.py} RENAMED
@@ -1,18 +1,18 @@
1
- # Run from repo root: python -m models.train (or cd models && python train.py)
2
-
3
  import json
4
- import os
5
  import random
6
 
7
- import numpy as np as np
8
  import torch
9
  import torch.nn as nn
10
  import torch.optim as optim
11
 
12
- from prepare_dataset import get_dataloaders
 
 
13
 
14
  CFG = {
15
- "model_name": "face_orientation", # "face_orientation" or "eye_behaviour"
16
  "epochs": 30,
17
  "batch_size": 32,
18
  "lr": 1e-3,
@@ -22,10 +22,25 @@ CFG = {
22
  "face_orientation": os.path.join(os.path.dirname(__file__), "face_orientation_model"),
23
  "eye_behaviour": os.path.join(os.path.dirname(__file__), "eye_behaviour_model"),
24
  },
25
- "logs_dir": os.path.join(os.path.dirname(__file__), "..", "evaluation", "logs"),
26
  }
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def set_seed(seed: int):
30
  random.seed(seed)
31
  np.random.seed(seed)
@@ -154,6 +169,16 @@ def main():
154
  history["val_loss"].append(round(val_loss, 4))
155
  history["val_acc"].append(round(val_acc, 4))
156
 
 
 
 
 
 
 
 
 
 
 
157
  marker = ""
158
  if val_acc > best_val_acc:
159
  best_val_acc = val_acc
 
 
 
1
  import json
2
+ import os, sys
3
  import random
4
 
5
+ import numpy as np
6
  import torch
7
  import torch.nn as nn
8
  import torch.optim as optim
9
 
10
+ from clearml import Task
11
+
12
+ from models.prepare_dataset import get_dataloaders
13
 
14
  CFG = {
15
+ "model_name": "face_orientation",
16
  "epochs": 30,
17
  "batch_size": 32,
18
  "lr": 1e-3,
 
22
  "face_orientation": os.path.join(os.path.dirname(__file__), "face_orientation_model"),
23
  "eye_behaviour": os.path.join(os.path.dirname(__file__), "eye_behaviour_model"),
24
  },
25
+ "logs_dir": os.path.join(os.path.dirname(__file__), "..", "..", "evaluation", "logs"),
26
  }
27
 
28
 
29
+ # ==== ClearML Initialisation =============================================
30
+ task = Task.init(
31
+ project_name="Focus Guard",
32
+ task_name=f"MLP Model Training",
33
+ tags=["training", "mlp_model"]
34
+ )
35
+
36
+ prefix = 'checkpoints/'+task.name+'_'+task.id+'/'
37
+ os.makedirs(prefix, exist_ok=True)
38
+
39
+ task.connect(CFG)
40
+
41
+
42
+
43
+ # ==== Model =============================================
44
  def set_seed(seed: int):
45
  random.seed(seed)
46
  np.random.seed(seed)
 
169
  history["val_loss"].append(round(val_loss, 4))
170
  history["val_acc"].append(round(val_acc, 4))
171
 
172
+
173
+ # Log scalars to ClearML
174
+ current_lr = optimizer.param_groups[0]['lr']
175
+ task.logger.report_scalar("Loss", "Train", float(train_loss), iteration=epoch)
176
+ task.logger.report_scalar("Accuracy", "Train", float(train_acc), iteration=epoch)
177
+ task.logger.report_scalar("Loss", "Val", float(val_loss), iteration=epoch)
178
+ task.logger.report_scalar("Accuracy", "Val", float(val_acc), iteration=epoch)
179
+ task.logger.report_scalar("Learning Rate", "LR", float(current_lr), iteration=epoch)
180
+ task.logger.flush()
181
+
182
  marker = ""
183
  if val_acc > best_val_acc:
184
  best_val_acc = val_acc
models/pretrained/__init__.py ADDED
File without changes