Yingtao-Zheng commited on
Commit
790e2c2
·
1 Parent(s): 6d2f7ec

Merge testing scripts from feauture/intergration

Browse files
tests/test_cnn_eye_attention_classifier.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import pytest
6
+
7
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8
+ if PROJECT_ROOT not in sys.path:
9
+ sys.path.insert(0, PROJECT_ROOT)
10
+
11
+
12
+ def test_load_eye_classifier_geometric_backend():
13
+ from models.cnn.eye_attention.classifier import load_eye_classifier, GeometricOnlyClassifier
14
+
15
+ clf = load_eye_classifier(path="anything.pt", backend="geometric")
16
+ # Assert that the returned object is indeed a GeometricOnlyClassifier
17
+ assert isinstance(clf, GeometricOnlyClassifier)
18
+ assert clf.name == "geometric"
19
+ # Assert that it can successfully process a dummy black image and return a score
20
+ assert clf.predict_score([np.zeros((10, 10, 3), dtype=np.uint8)]) == 1.0
21
+
22
+
23
+ def test_load_eye_classifier_none_path_falls_back_to_geometric(capsys):
24
+ from models.cnn.eye_attention.classifier import load_eye_classifier, GeometricOnlyClassifier
25
+ # Intentionally request the 'yolo' backend but provide no model path
26
+ clf = load_eye_classifier(path=None, backend="yolo")
27
+ assert isinstance(clf, GeometricOnlyClassifier)
28
+ out = capsys.readouterr().out
29
+ assert "falling back to geometric" in out
30
+
31
+
32
+ def test_load_eye_classifier_extension_overrides_backend(monkeypatch, tmp_path):
33
+ """
34
+ .pth -> cnn, .pt -> yolo(但不真的加载 torch/ultralytics)
35
+ """
36
+ import models.cnn.eye_attention.classifier as m
37
+
38
+ created = {}
39
+ #Create dummy model classes to mock real deep learning models
40
+ class DummyCNN(m.EyeClassifier):
41
+ @property
42
+ def name(self) -> str:
43
+ return "dummy_cnn"
44
+
45
+ def predict_score(self, crops_bgr):
46
+ return 0.5
47
+
48
+ class DummyYOLO(m.EyeClassifier):
49
+ @property
50
+ def name(self) -> str:
51
+ return "dummy_yolo"
52
+
53
+ def predict_score(self, crops_bgr):
54
+ return 0.7
55
+ # Intercept the original constructors to return our dummy models
56
+ def fake_cnn_ctor(path, device="cpu"):
57
+ created["cnn"] = (path, device)
58
+ return DummyCNN()
59
+
60
+ def fake_yolo_ctor(path, device="cpu"):
61
+ created["yolo"] = (path, device)
62
+ return DummyYOLO()
63
+
64
+ monkeypatch.setattr(m, "EyeCNNClassifier", fake_cnn_ctor)
65
+ monkeypatch.setattr(m, "YOLOv11Classifier", fake_yolo_ctor)
66
+ #Test the '.pth' extension logic
67
+ pth = tmp_path / "eye_model.pth"
68
+ pth.write_bytes(b"not a real checkpoint")
69
+ clf = m.load_eye_classifier(path=str(pth), backend="yolo", device="cpu")
70
+ assert clf.name == "dummy_cnn"
71
+ assert "cnn" in created and created["cnn"][0] == str(pth)
72
+ #Test the '.pt' extension logic
73
+ pt = tmp_path / "eye_model.pt"
74
+ pt.write_bytes(b"not a real yolo model")
75
+ clf2 = m.load_eye_classifier(path=str(pt), backend="cnn", device="cpu")
76
+ assert clf2.name == "dummy_yolo"
77
+ assert "yolo" in created and created["yolo"][0] == str(pt)
78
+
79
+
80
+ def test_load_eye_classifier_unknown_backend_raises():
81
+ from models.cnn.eye_attention.classifier import load_eye_classifier
82
+
83
+ with pytest.raises(ValueError):
84
+ # Use an unknown extension to prevent the automatic fallback mechanism
85
+ load_eye_classifier(path="x.unknownext", backend="unknown")
86
+
tests/test_cnn_eye_attention_crop.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import types
4
+ import importlib
5
+
6
+ import numpy as np
7
+
8
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9
+ if PROJECT_ROOT not in sys.path:
10
+ sys.path.insert(0, PROJECT_ROOT)
11
+
12
+
13
+ def _install_fake_facemesh_module():
14
+ pkg_models = sys.modules.setdefault("models", types.ModuleType("models"))
15
+ pkg_pretrained = sys.modules.setdefault(
16
+ "models.pretrained", types.ModuleType("models.pretrained")
17
+ )
18
+ pkg_face_mesh = sys.modules.setdefault(
19
+ "models.pretrained.face_mesh", types.ModuleType("models.pretrained.face_mesh")
20
+ )
21
+ mod = types.ModuleType("models.pretrained.face_mesh.face_mesh")
22
+
23
+ class FaceMeshDetector:
24
+ # get landmarks
25
+ LEFT_EYE_INDICES = [0, 1, 2]
26
+ RIGHT_EYE_INDICES = [3, 4, 5]
27
+
28
+ mod.FaceMeshDetector = FaceMeshDetector
29
+
30
+ sys.modules["models.pretrained.face_mesh.face_mesh"] = mod
31
+ pkg_models.pretrained = pkg_pretrained
32
+ pkg_pretrained.face_mesh = pkg_face_mesh
33
+ pkg_face_mesh.face_mesh = mod
34
+
35
+
36
+ def test_bbox_from_landmarks_clamps_to_frame():
37
+ _install_fake_facemesh_module()
38
+ crop = importlib.import_module("models.cnn.eye_attention.crop")
39
+ importlib.reload(crop)
40
+
41
+ # normalize the cord
42
+ landmarks = np.array(
43
+ [
44
+ [0.1, 0.1, 0.0],
45
+ [0.2, 0.1, 0.0],
46
+ [0.15, 0.2, 0.0],
47
+ [0.8, 0.1, 0.0],
48
+ [0.9, 0.1, 0.0],
49
+ [0.85, 0.2, 0.0],
50
+ ],
51
+ dtype=np.float32,
52
+ )
53
+
54
+ x1, y1, x2, y2 = crop._bbox_from_landmarks(landmarks, [0, 1, 2], 100, 50, expand=0.0)
55
+ assert 0 <= x1 < x2 <= 100
56
+ assert 0 <= y1 < y2 <= 50
57
+
58
+
59
+ def test_extract_eye_crops_returns_expected_shapes():
60
+ _install_fake_facemesh_module()
61
+ crop = importlib.import_module("models.cnn.eye_attention.crop")
62
+ importlib.reload(crop)
63
+
64
+ frame = np.zeros((60, 120, 3), dtype=np.uint8)
65
+ landmarks = np.array(
66
+ [
67
+ [0.1, 0.2, 0.0],
68
+ [0.2, 0.2, 0.0],
69
+ [0.15, 0.3, 0.0],
70
+ [0.7, 0.2, 0.0],
71
+ [0.8, 0.2, 0.0],
72
+ [0.75, 0.3, 0.0],
73
+ ],
74
+ dtype=np.float32,
75
+ )
76
+
77
+ left, right, left_bbox, right_bbox = crop.extract_eye_crops(
78
+ frame, landmarks, expand=0.0, crop_size=32
79
+ )
80
+ assert left.shape == (32, 32, 3)
81
+ assert right.shape == (32, 32, 3)
82
+ assert len(left_bbox) == 4
83
+ assert len(right_bbox) == 4
84
+
tests/test_pipeline_integration.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+
6
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
7
+ if PROJECT_ROOT not in sys.path:
8
+ sys.path.insert(0, PROJECT_ROOT)
9
+
10
+ from ui.pipeline import FaceMeshPipeline
11
+
12
+
13
+ class _DummyDetector:
14
+ def __init__(self, landmarks=None):
15
+ self._landmarks = landmarks
16
+
17
+ def process(self, bgr_frame):
18
+ return self._landmarks
19
+
20
+ def close(self):
21
+ return None
22
+
23
+
24
+ def test_face_mesh_pipeline_no_face_returns_expected_keys():
25
+ pipe = FaceMeshPipeline(detector=_DummyDetector(landmarks=None), eye_backend="geometric")
26
+ frame = np.zeros((480, 640, 3), dtype=np.uint8)
27
+ out = pipe.process_frame(frame)
28
+
29
+ assert isinstance(out, dict)
30
+ for k in ("landmarks", "s_face", "s_eye", "raw_score", "is_focused", "yaw", "pitch", "roll", "mar", "is_yawning"):
31
+ assert k in out
32
+ assert out["landmarks"] is None
33
+ assert 0.0 <= float(out["raw_score"]) <= 1.0
34
+
35
+
36
+ def test_face_mesh_pipeline_with_fake_landmarks_runs():
37
+ fake_lm = np.zeros((478, 2), dtype=np.float32)
38
+ pipe = FaceMeshPipeline(detector=_DummyDetector(landmarks=fake_lm), eye_backend="geometric")
39
+ frame = np.zeros((480, 640, 3), dtype=np.uint8)
40
+ out = pipe.process_frame(frame)
41
+
42
+ assert out["landmarks"] is not None
43
+ assert "is_focused" in out
44
+ assert "raw_score" in out
45
+ assert 0.0 <= float(out["raw_score"]) <= 1.0
46
+