k22056537 committed on
Commit
da26163
·
1 Parent(s): 76adc7f

feat: add optional eye model (YOLO/MobileNet) alongside geometry

Browse files
models/eye_behaviour/eye_classifier.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Swappable eye classifier: geometric only, MobileNetV2 (96x96), or YOLO open/closed (224x224)
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+
7
+ import cv2
8
+ import numpy as np
9
+
10
+
class EyeClassifier(ABC):
    """Interface for swappable eye-state classifiers.

    Implementations receive BGR eye crops and return an attentiveness
    score in [0, 1], where 1.0 means attentive (eyes open).
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short backend identifier (e.g. "geometric", "mobilenet", "yolo")."""

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Score [left_crop, right_crop] BGR crops; 1.0 = attentive (open)."""
22
+
class GeometricOnlyClassifier(EyeClassifier):
    """Fallback backend: no learned model; the geometric EAR/gaze score decides."""

    @property
    def name(self) -> str:
        return "geometric"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        # Neutral score so the fusion reduces to the geometric signal alone.
        return 1.0
31
+
class MobileNetV2Classifier(EyeClassifier):
    """MobileNetV2-based eye classifier.

    Expects BGR eye crops; each crop is resized to CROP_SIZE x CROP_SIZE
    (96x96) and normalized via ``crop_to_tensor`` (ImageNet statistics).
    """

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Load an EyeAttentionModel state dict onto *device*.

        Heavy imports are deferred so this backend stays optional.
        """
        import torch

        from models.eye_behaviour.eye_attention_model import EyeAttentionModel
        from models.eye_behaviour.eye_crop import crop_to_tensor, CROP_SIZE

        self._crop_to_tensor = crop_to_tensor
        self._crop_size = CROP_SIZE
        self._device = torch.device(device)

        self._model = EyeAttentionModel(pretrained=False).to(self._device)
        # weights_only=True: safe tensor-only deserialization of the state dict.
        self._model.load_state_dict(
            torch.load(checkpoint_path, map_location=self._device, weights_only=True)
        )
        self._model.eval()

    @property
    def name(self) -> str:
        return "mobilenet"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean open-eye score over the usable BGR crops.

        Crops that are None or zero-size (e.g. an eye clipped at the frame
        edge) are skipped — cv2.resize would raise on them. With no usable
        crops, the neutral 1.0 is returned so geometry alone drives S_eye.
        """
        import torch

        tensors = []
        for crop in crops_bgr:
            if crop is None or crop.size == 0:
                continue  # zero-size crop would make cv2.resize raise
            resized = cv2.resize(
                crop, (self._crop_size, self._crop_size), interpolation=cv2.INTER_AREA
            )
            tensors.append(self._crop_to_tensor(resized))
        if not tensors:
            return 1.0
        batch = torch.stack(tensors).to(self._device)
        with torch.no_grad():
            scores = self._model.predict_score(batch)
        return scores.mean().item()
68
+
class YOLOv11Classifier(EyeClassifier):
    """YOLO classification backend for open/closed eyes.

    Ultralytics resizes inputs to the model's own size (224x224)
    internally, so raw BGR crops can be passed straight in.
    """

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Load a YOLO classification checkpoint.

        Raises ImportError when ultralytics is not installed.
        """
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device

        # Find the class index meaning "eyes open". Match case-insensitively:
        # training datasets often capitalize class names ("Open"), which the
        # exact-match version silently missed.
        names = self._model.names
        self._attentive_idx = None
        for idx, cls_name in names.items():
            if cls_name.lower() in ("open", "attentive"):
                self._attentive_idx = idx
                break
        if self._attentive_idx is None:
            # Fallback heuristic: assume the highest class index is "open".
            self._attentive_idx = max(names.keys())
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    @property
    def name(self) -> str:
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean attentive-class probability over usable crops.

        None/zero-size crops are dropped; with nothing to score, the
        neutral 1.0 is returned so geometry alone drives S_eye.
        """
        crops = [c for c in crops_bgr if c is not None and c.size > 0]
        if not crops:
            return 1.0
        results = self._model.predict(crops, device=self._device, verbose=False)
        scores = [float(r.probs.data[self._attentive_idx]) for r in results]
        return sum(scores) / len(scores) if scores else 1.0
98
+
def _is_yolo_checkpoint(path: str) -> bool:
    """Best-effort sniff: does *path* look like an ultralytics YOLO checkpoint?

    Returns False on any load failure (missing file, corrupt data, ...).
    """
    try:
        import torch

        # SECURITY: weights_only=False executes arbitrary pickle code from
        # the file. Only point this at checkpoints you trust.
        data = torch.load(path, map_location="cpu", weights_only=False)
    except Exception:
        return False
    if not isinstance(data, dict):
        return False
    model_obj = data.get("model")
    if model_obj is not None and "Model" in type(model_obj).__name__:
        return True
    return "train_args" in data and "model" in data
114
+
def load_eye_classifier(
    path: str | None = None,
    backend: str = "auto",
    device: str = "cpu",
) -> EyeClassifier:
    """Build an eye classifier for *path* with the requested backend.

    Args:
        path: Checkpoint file, or None for geometry-only scoring.
        backend: "geometric", "mobilenet", "yolo", or "auto" to sniff the
            checkpoint format.
        device: Torch/ultralytics device string (e.g. "cpu", "cuda:0").

    Returns:
        An EyeClassifier; GeometricOnlyClassifier when no model is requested.

    Raises:
        ValueError: Unknown *backend* name (previously a typo fell through
            to auto-detection silently).
        ImportError: A YOLO backend was requested/detected but ultralytics
            is not installed.
    """
    if backend not in ("auto", "geometric", "mobilenet", "yolo"):
        raise ValueError(
            f"Unknown eye backend {backend!r}; expected auto/geometric/mobilenet/yolo"
        )

    if path is None or backend == "geometric":
        return GeometricOnlyClassifier()

    if backend == "yolo":
        try:
            return YOLOv11Classifier(path, device=device)
        except ImportError:
            print("[CLASSIFIER] ultralytics required. pip install ultralytics")
            raise

    if backend == "mobilenet":
        return MobileNetV2Classifier(path, device=device)

    # backend == "auto": sniff the checkpoint, preferring YOLO when it looks
    # like an ultralytics export.
    if _is_yolo_checkpoint(path):
        try:
            return YOLOv11Classifier(path, device=device)
        except ImportError:
            print("[CLASSIFIER] YOLO checkpoint needs ultralytics. pip install ultralytics")
            raise
    try:
        return MobileNetV2Classifier(path, device=device)
    except Exception as exc:
        # Last resort: a YOLO checkpoint the sniffer missed fails the
        # MobileNet weights_only load with an ultralytics-specific error.
        err = str(exc)
        if "Weights only load failed" in err and "ultralytics" in err:
            try:
                return YOLOv11Classifier(path, device=device)
            except ImportError:
                print("[CLASSIFIER] pip install ultralytics for this checkpoint")
                raise
        raise
requirements.txt CHANGED
@@ -4,3 +4,4 @@ opencv-python>=4.8.0
4
  numpy>=1.24.0
5
  torch>=2.0.0
6
  torchvision>=0.15.0
 
 
4
  numpy>=1.24.0
5
  torch>=2.0.0
6
  torchvision>=0.15.0
7
+ # ultralytics # optional: for YOLO open/closed eye classifier
ui/README.md CHANGED
@@ -2,14 +2,21 @@
2
 
3
  Live demo and session view.
4
 
5
- ## Stage 1 (face mesh only)
6
 
7
- - **pipeline.py** — frame → 478 landmarks (no head pose / CNN).
8
- - **live_demo.py** — webcam + mesh overlay (tessellation, contours, eyes, irises).
9
 
10
  From repo root:
11
  ```bash
12
  pip install -r requirements.txt
13
  python ui/live_demo.py
14
  ```
 
 
 
 
 
 
 
15
  `q` = quit, `m` = cycle mesh mode (full / contours / off).
 
2
 
3
  Live demo and session view.
4
 
5
+ ## Stage 2 (face mesh + head pose + eye)
6
 
7
+ - **pipeline.py** — face mesh S_face (head pose) + S_eye (geometry + optional YOLO/MobileNet) + MAR/yawn → focus.
8
+ - **live_demo.py** — webcam + mesh, FOCUSED/NOT FOCUSED, MAR, YAWN, optional eye model.
9
 
10
  From repo root:
11
  ```bash
12
  pip install -r requirements.txt
13
  python ui/live_demo.py
14
  ```
15
+ With YOLO open/closed model (face mesh crops eyes → 224×224 → YOLO):
16
+ ```bash
17
+ pip install ultralytics
18
+ python ui/live_demo.py --eye-model path/to/yolo.pt --eye-backend yolo
19
+ ```
20
+ With MobileNetV2 (96×96 crops): `--eye-model path/to/best_model.pt --eye-backend mobilenet`.
21
+
22
  `q` = quit, `m` = cycle mesh mode (full / contours / off).
ui/live_demo.py CHANGED
@@ -125,10 +125,22 @@ def main():
125
  parser.add_argument("--alpha", type=float, default=0.4, help="S_face weight")
126
  parser.add_argument("--beta", type=float, default=0.6, help="S_eye weight")
127
  parser.add_argument("--threshold", type=float, default=0.55, help="Score >= this = FOCUSED (higher = stricter)")
 
 
 
128
  args = parser.parse_args()
129
 
130
- print("[DEMO] Face mesh + head pose + eye behaviour (Stage 2)")
131
- pipeline = FaceMeshPipeline(max_angle=args.max_angle, alpha=args.alpha, beta=args.beta, threshold=args.threshold)
 
 
 
 
 
 
 
 
 
132
 
133
  cap = cv2.VideoCapture(args.camera)
134
  if not cap.isOpened():
@@ -161,6 +173,11 @@ def main():
161
  draw_contours(frame, lm, w, h)
162
  draw_eyes_and_irises(frame, lm, w, h)
163
  pipeline.head_pose.draw_axes(frame, lm)
 
 
 
 
 
164
 
165
  # Status bar: FOCUSED / NOT FOCUSED; YAWN when mouth open (sleepy)
166
  status = "FOCUSED" if result["is_focused"] else "NOT FOCUSED"
@@ -173,7 +190,8 @@ def main():
173
  cv2.putText(frame, "YAWN", (10, 75), FONT, 0.7, ORANGE, 2, cv2.LINE_AA)
174
  if result["yaw"] is not None:
175
  cv2.putText(frame, f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}", (w - 280, 48), FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
176
- cv2.putText(frame, f"{_MESH_NAMES[mesh_mode]} FPS: {fps:.0f}", (w - 200, 28), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
 
177
  cv2.putText(frame, "q:quit m:mesh", (w - 140, 48), FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
178
 
179
  cv2.imshow("FocusGuard", frame)
 
125
  parser.add_argument("--alpha", type=float, default=0.4, help="S_face weight")
126
  parser.add_argument("--beta", type=float, default=0.6, help="S_eye weight")
127
  parser.add_argument("--threshold", type=float, default=0.55, help="Score >= this = FOCUSED (higher = stricter)")
128
+ parser.add_argument("--eye-model", type=str, default=None, help="Path to eye model (YOLO .pt or MobileNet .pt); omit = geometry only")
129
+ parser.add_argument("--eye-backend", type=str, default="auto", choices=["auto", "mobilenet", "yolo", "geometric"], help="Eye model backend (auto = detect from file)")
130
+ parser.add_argument("--eye-blend", type=float, default=0.5, help="Blend: (1-blend)*geo + blend*model when model loaded")
131
  args = parser.parse_args()
132
 
133
+ eye_mode = " + model" if args.eye_model else " only"
134
+ print("[DEMO] Face mesh + head pose + eye (geometry" + eye_mode + ")")
135
+ pipeline = FaceMeshPipeline(
136
+ max_angle=args.max_angle,
137
+ alpha=args.alpha,
138
+ beta=args.beta,
139
+ threshold=args.threshold,
140
+ eye_model_path=args.eye_model,
141
+ eye_backend=args.eye_backend,
142
+ eye_blend=args.eye_blend,
143
+ )
144
 
145
  cap = cv2.VideoCapture(args.camera)
146
  if not cap.isOpened():
 
173
  draw_contours(frame, lm, w, h)
174
  draw_eyes_and_irises(frame, lm, w, h)
175
  pipeline.head_pose.draw_axes(frame, lm)
176
+ if result.get("left_bbox") and result.get("right_bbox"):
177
+ lx1, ly1, lx2, ly2 = result["left_bbox"]
178
+ rx1, ry1, rx2, ry2 = result["right_bbox"]
179
+ cv2.rectangle(frame, (lx1, ly1), (lx2, ly2), YELLOW, 1)
180
+ cv2.rectangle(frame, (rx1, ry1), (rx2, ry2), YELLOW, 1)
181
 
182
  # Status bar: FOCUSED / NOT FOCUSED; YAWN when mouth open (sleepy)
183
  status = "FOCUSED" if result["is_focused"] else "NOT FOCUSED"
 
190
  cv2.putText(frame, "YAWN", (10, 75), FONT, 0.7, ORANGE, 2, cv2.LINE_AA)
191
  if result["yaw"] is not None:
192
  cv2.putText(frame, f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}", (w - 280, 48), FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
193
+ eye_label = f"eye:{pipeline.eye_classifier.name}" if pipeline.has_eye_model else "eye:geo"
194
+ cv2.putText(frame, f"{_MESH_NAMES[mesh_mode]} {eye_label} FPS: {fps:.0f}", (w - 320, 28), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
195
  cv2.putText(frame, "q:quit m:mesh", (w - 140, 48), FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
196
 
197
  cv2.imshow("FocusGuard", frame)
ui/pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- # Stage 2: face mesh + head pose (S_face) + eye behaviour (S_eye) -> focus
2
 
3
  import os
4
  import sys
@@ -12,18 +12,39 @@ if _PROJECT_ROOT not in sys.path:
12
  from models.face_mesh.face_mesh import FaceMeshDetector
13
  from models.face_orientation.head_pose import HeadPoseEstimator
14
  from models.eye_behaviour.eye_scorer import EyeBehaviourScorer, compute_mar, MAR_YAWN_THRESHOLD
 
 
15
 
16
 
17
  class FaceMeshPipeline:
18
- # frame -> face mesh -> S_face + S_eye -> focused / not focused
19
-
20
- def __init__(self, max_angle: float = 22.0, alpha: float = 0.4, beta: float = 0.6, threshold: float = 0.55):
 
 
 
 
 
 
 
 
 
21
  self.detector = FaceMeshDetector()
22
  self.head_pose = HeadPoseEstimator(max_angle=max_angle)
23
  self.eye_scorer = EyeBehaviourScorer()
24
  self.alpha = alpha
25
  self.beta = beta
26
  self.threshold = threshold
 
 
 
 
 
 
 
 
 
 
27
 
28
  def process_frame(self, bgr_frame: np.ndarray) -> dict:
29
  landmarks = self.detector.process(bgr_frame)
@@ -40,6 +61,8 @@ class FaceMeshPipeline:
40
  "roll": None,
41
  "mar": None,
42
  "is_yawning": False,
 
 
43
  }
44
 
45
  if landmarks is None:
@@ -51,19 +74,31 @@ class FaceMeshPipeline:
51
  out["yaw"], out["pitch"], out["roll"] = angles
52
  out["s_face"] = self.head_pose.score(landmarks, w, h)
53
 
54
- # Eye behaviour (EAR + gaze) -> S_eye
55
- out["s_eye"] = self.eye_scorer.score(landmarks)
56
-
57
- # Mouth open (MAR) -> yawn / sleepy: force NOT FOCUSED when mouth open
 
 
 
 
 
 
 
 
58
  out["mar"] = compute_mar(landmarks)
59
  out["is_yawning"] = out["mar"] > MAR_YAWN_THRESHOLD
60
 
61
- # Fusion: alpha*S_face + beta*S_eye; if yawning (mouth open) -> not focused
62
  out["raw_score"] = self.alpha * out["s_face"] + self.beta * out["s_eye"]
63
  out["is_focused"] = out["raw_score"] >= self.threshold and not out["is_yawning"]
64
 
65
  return out
66
 
 
 
 
 
67
  def close(self):
68
  self.detector.close()
69
 
 
1
+ # Stage 2: face mesh + head pose (S_face) + eye (geometry + optional model) -> focus
2
 
3
  import os
4
  import sys
 
12
  from models.face_mesh.face_mesh import FaceMeshDetector
13
  from models.face_orientation.head_pose import HeadPoseEstimator
14
  from models.eye_behaviour.eye_scorer import EyeBehaviourScorer, compute_mar, MAR_YAWN_THRESHOLD
15
+ from models.eye_behaviour.eye_crop import extract_eye_crops
16
+ from models.eye_behaviour.eye_classifier import load_eye_classifier, GeometricOnlyClassifier
17
 
18
 
19
  class FaceMeshPipeline:
20
+ # frame -> face mesh -> S_face + S_eye (geo + optional YOLO/MobileNet) -> focused / not focused
21
+
22
+ def __init__(
23
+ self,
24
+ max_angle: float = 22.0,
25
+ alpha: float = 0.4,
26
+ beta: float = 0.6,
27
+ threshold: float = 0.55,
28
+ eye_model_path: str | None = None,
29
+ eye_backend: str = "auto",
30
+ eye_blend: float = 0.5,
31
+ ):
32
  self.detector = FaceMeshDetector()
33
  self.head_pose = HeadPoseEstimator(max_angle=max_angle)
34
  self.eye_scorer = EyeBehaviourScorer()
35
  self.alpha = alpha
36
  self.beta = beta
37
  self.threshold = threshold
38
+ self.eye_blend = eye_blend # 0.5 = 50% geo + 50% model when model loaded
39
+
40
+ self.eye_classifier = load_eye_classifier(
41
+ path=eye_model_path if eye_model_path and os.path.exists(eye_model_path) else None,
42
+ backend=eye_backend,
43
+ device="cpu",
44
+ )
45
+ self._has_eye_model = not isinstance(self.eye_classifier, GeometricOnlyClassifier)
46
+ if self._has_eye_model:
47
+ print(f"[PIPELINE] Eye model: {self.eye_classifier.name}")
48
 
49
  def process_frame(self, bgr_frame: np.ndarray) -> dict:
50
  landmarks = self.detector.process(bgr_frame)
 
61
  "roll": None,
62
  "mar": None,
63
  "is_yawning": False,
64
+ "left_bbox": None,
65
+ "right_bbox": None,
66
  }
67
 
68
  if landmarks is None:
 
74
  out["yaw"], out["pitch"], out["roll"] = angles
75
  out["s_face"] = self.head_pose.score(landmarks, w, h)
76
 
77
+ # Eye: geometry (EAR + gaze) always; optional model (YOLO/MobileNet) on cropped eyes
78
+ s_eye_geo = self.eye_scorer.score(landmarks)
79
+ if self._has_eye_model:
80
+ left_crop, right_crop, left_bbox, right_bbox = extract_eye_crops(bgr_frame, landmarks)
81
+ out["left_bbox"] = left_bbox
82
+ out["right_bbox"] = right_bbox
83
+ s_eye_model = self.eye_classifier.predict_score([left_crop, right_crop])
84
+ out["s_eye"] = (1.0 - self.eye_blend) * s_eye_geo + self.eye_blend * s_eye_model
85
+ else:
86
+ out["s_eye"] = s_eye_geo
87
+
88
+ # Mouth open (MAR) -> yawn: force NOT FOCUSED when mouth open
89
  out["mar"] = compute_mar(landmarks)
90
  out["is_yawning"] = out["mar"] > MAR_YAWN_THRESHOLD
91
 
92
+ # Fusion; yawn overrides
93
  out["raw_score"] = self.alpha * out["s_face"] + self.beta * out["s_eye"]
94
  out["is_focused"] = out["raw_score"] >= self.threshold and not out["is_yawning"]
95
 
96
  return out
97
 
98
+ @property
99
+ def has_eye_model(self) -> bool:
100
+ return self._has_eye_model
101
+
102
  def close(self):
103
  self.detector.close()
104