Abdelrahman Almatrooshi committed on
Commit 2eba0cc · 1 Parent(s): fad97ce

Integrate L2CS-Net gaze estimation

- Add L2CS-Net in-tree (models/L2CS-Net/) with Gaze360 weights via Git LFS
- L2CSPipeline: ResNet50 gaze + MediaPipe head pose, roll de-rotation, cosine scoring
- 9-point polynomial gaze calibration with bias correction and IQR outlier filtering
- Gaze-eye fusion: calibrated screen coords + EAR for focus detection
- L2CS Boost mode: runs gaze alongside any base model (35/65 weight, veto at 0.38)
- Calibration UI: fullscreen overlay, auto-advance, progress ring
- Frontend: GAZE toggle, Calibrate button, gaze pointer dot on canvas
- Bumped capture resolution to 640x480 @ JPEG 0.75
- Dockerfile: added git, CPU-only torch for HF Space deployment

Dockerfile CHANGED
@@ -7,7 +7,14 @@ ENV PYTHONUNBUFFERED=1
 
 WORKDIR /app
 
-RUN apt-get update && apt-get install -y --no-install-recommends libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev build-essential nodejs npm && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 \
+    ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev \
+    libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev \
+    build-essential nodejs npm git \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
 COPY requirements.txt ./
 RUN pip install --no-cache-dir -r requirements.txt
README.md CHANGED
@@ -1,10 +1,94 @@
 ---
-title: IntegrationTest
-emoji: 📚
+title: FocusGuard
 colorFrom: indigo
 colorTo: purple
 sdk: docker
 pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# FocusGuard - Real-Time Focus Detection
+
+A web app that monitors whether you're focused on your screen using your webcam. Combines head pose estimation, eye behaviour analysis, and deep learning gaze tracking to detect attention in real time.
+
+## How It Works
+
+1. **Open the app** and click **Start** - your webcam feed appears with a face mesh overlay.
+2. **Pick a model** from the selector bar (Geometric, XGBoost, L2CS, etc.).
+3. The system analyses each frame and shows **FOCUSED** or **NOT FOCUSED** with a confidence score.
+4. A timeline tracks your focus over time. Session history is saved for review.
+
+## Models
+
+| Model | What it uses | Best for |
+|-------|-------------|----------|
+| **Geometric** | Head pose angles + eye aspect ratio (EAR) | Fast, no ML needed |
+| **XGBoost** | Trained classifier on head/eye features | Balanced accuracy/speed |
+| **MLP** | Neural network on same features | Higher accuracy |
+| **Hybrid** | Weighted MLP + Geometric ensemble | Best head-pose accuracy |
+| **L2CS** | Deep gaze estimation (ResNet50) | Detects eye-only gaze shifts |
+
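The eye aspect ratio (EAR) used by the Geometric model can be sketched with the standard formula: the ratio of the eye's two vertical landmark distances to its horizontal distance, which drops toward zero as the eye closes. This is a minimal illustration assuming six eye-contour points in the usual p1..p6 order; the repo's own scoring lives in `models/eye_scorer.py` and may differ in detail.

```python
import numpy as np

def eye_aspect_ratio(eye: np.ndarray) -> float:
    """EAR = (|p2-p6| + |p3-p5|) / (2 * |p1-p4|) for a 6x2 array of eye landmarks."""
    v1 = np.linalg.norm(eye[1] - eye[5])  # first vertical distance
    v2 = np.linalg.norm(eye[2] - eye[4])  # second vertical distance
    h = np.linalg.norm(eye[0] - eye[3])   # horizontal (corner-to-corner) distance
    return (v1 + v2) / (2.0 * h)

# A wide-open eye is tall relative to its width, giving a higher EAR.
open_eye = np.array([[0, 0], [2, 2], [4, 2], [6, 0], [4, -2], [2, -2]], float)
print(round(eye_aspect_ratio(open_eye), 3))
```

A fixed threshold on EAR (or a drop relative to a per-user baseline) is the usual way to flag closed or drooping eyes.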
+## L2CS Gaze Tracking
+
+L2CS-Net predicts where your eyes are looking, not just where your head is pointed. This catches the scenario where your head faces the screen but your eyes wander.
+
+### Standalone mode
+Select **L2CS** as the model - it handles everything.
+
+### Boost mode
+Select any other model, then click the **GAZE** toggle. L2CS runs alongside the base model:
+- Base model handles head pose and eye openness (35% weight)
+- L2CS handles gaze direction (65% weight)
+- If L2CS detects gaze is clearly off-screen, it **vetoes** the base model regardless of score
+
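The boost-mode arithmetic can be sketched as a small pure function. The 35/65 weights and 0.38 veto come from this commit; the 0.52 decision threshold and the 0.8 veto scaling mirror the constants in `_process_frame_with_l2cs_boost` in `main.py`.

```python
def fuse_scores(base_score: float, l2cs_score: float,
                base_w: float = 0.35, l2cs_w: float = 0.65,
                veto: float = 0.38) -> tuple[float, bool]:
    """Weighted score fusion with an off-screen veto; returns (fused_score, is_focused)."""
    if l2cs_score < veto:
        # Gaze clearly off-screen: the base model is overridden regardless of its score.
        return l2cs_score * 0.8, False
    fused = base_w * base_score + l2cs_w * l2cs_score
    return fused, fused >= 0.52

print(fuse_scores(0.4, 0.7))  # gaze on-screen: 0.35*0.4 + 0.65*0.7 = 0.595 -> focused
print(fuse_scores(0.9, 0.2))  # veto fires: focused head pose cannot rescue wandering eyes
```

The veto is the key asymmetry: a confident base model can never outvote a clearly off-screen gaze.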
+### Calibration
+After enabling L2CS or Gaze Boost, click **Calibrate** while a session is running:
+1. A fullscreen overlay shows 9 target dots (3x3 grid)
+2. Look at each dot as the progress ring fills
+3. The first dot (centre) sets your baseline gaze offset
+4. After all 9 points, a polynomial model maps your gaze angles to screen coordinates
+5. A cyan tracking dot appears on the video showing where you're looking
+
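The angle-to-screen mapping in step 4 can be sketched as a least-squares fit over quadratic features, with the 1.5×IQR gate mentioned in the commit message applied per point before fitting. This is a hypothetical illustration of the technique only; the actual `models/gaze_calibration.py` is not shown in this diff and its basis, filtering, and bias correction may differ.

```python
import numpy as np

def iqr_filter(samples):
    """Keep rows within 1.5*IQR of the per-axis quartiles (a simple outlier gate)."""
    s = np.asarray(samples, float)
    q1, q3 = np.percentile(s, [25, 75], axis=0)
    iqr = q3 - q1
    lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
    return s[np.all((s >= lo) & (s <= hi), axis=1)]

def fit_gaze_polynomial(yaws, pitches, targets):
    """Least-squares quadratic map from (yaw, pitch) angles to screen (x, y)."""
    yaws, pitches = np.asarray(yaws, float), np.asarray(pitches, float)
    # Design matrix: [1, yaw, pitch, yaw^2, pitch^2, yaw*pitch] per calibration sample.
    A = np.column_stack([np.ones_like(yaws), yaws, pitches,
                         yaws ** 2, pitches ** 2, yaws * pitches])
    coeffs, *_ = np.linalg.lstsq(A, np.asarray(targets, float), rcond=None)
    return coeffs  # shape (6, 2): one coefficient column per screen axis

def predict_screen(coeffs, yaw, pitch):
    feats = np.array([1.0, yaw, pitch, yaw ** 2, pitch ** 2, yaw * pitch])
    return feats @ coeffs
```

With 9 grid points and 6 coefficients per axis the fit is overdetermined, so a few noisy samples (after IQR filtering) still produce a stable mapping.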
+## Tech Stack
+
+- **Backend**: FastAPI + WebSocket, Python 3.10
+- **Frontend**: React + Vite
+- **Face detection**: MediaPipe Face Landmarker (478 landmarks)
+- **Gaze estimation**: L2CS-Net (ResNet50, Gaze360 weights)
+- **ML models**: XGBoost, PyTorch MLP
+- **Deployment**: Docker on Hugging Face Spaces
+
+## Running Locally
+
+```bash
+# install Python deps
+pip install -r requirements.txt
+
+# install frontend deps and build
+npm install && npm run build
+
+# start the server
+uvicorn main:app --port 8000
+```
+
+Open `http://localhost:8000` in your browser.
+
+## Project Structure
+
+```
+main.py                      # FastAPI app, WebSocket handler, API endpoints
+ui/pipeline.py               # All focus detection pipelines (Geometric, MLP, XGBoost, Hybrid, L2CS)
+models/
+  face_mesh.py               # MediaPipe face landmark detector
+  head_pose.py               # Head pose estimation from landmarks
+  eye_scorer.py              # EAR/eye behaviour scoring
+  gaze_calibration.py        # 9-point polynomial gaze calibration
+  gaze_eye_fusion.py         # Fuses calibrated gaze with eye openness
+  L2CS-Net/                  # In-tree L2CS-Net repo with Gaze360 weights
+src/
+  components/
+    FocusPageLocal.jsx       # Main focus page (camera, controls, model selector)
+    CalibrationOverlay.jsx   # Fullscreen calibration UI
+  utils/
+    VideoManagerLocal.js     # WebSocket client, frame capture, canvas rendering
+Dockerfile                   # Docker build for HF Spaces
+```
download_l2cs_weights.py ADDED
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+# Downloads L2CS-Net Gaze360 weights into checkpoints/
+
+import os
+import sys
+
+CHECKPOINTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
+DEST = os.path.join(CHECKPOINTS_DIR, "L2CSNet_gaze360.pkl")
+GDRIVE_ID = "1dL2Jokb19_SBSHAhKHOxJsmYs5-GoyLo"
+
+
+def main():
+    if os.path.isfile(DEST):
+        print(f"[OK] Weights already at {DEST}")
+        return
+
+    try:
+        import gdown
+    except ImportError:
+        print("gdown not installed. Run: pip install gdown")
+        sys.exit(1)
+
+    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
+    print(f"Downloading L2CS-Net weights to {DEST} ...")
+    gdown.download(f"https://drive.google.com/uc?id={GDRIVE_ID}", DEST, quiet=False)
+
+    if os.path.isfile(DEST):
+        print(f"[OK] Downloaded ({os.path.getsize(DEST) / 1024 / 1024:.1f} MB)")
+    else:
+        print("[ERR] Download failed. Manual download:")
+        print("  https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd")
+        print(f"  Place L2CSNet_gaze360.pkl in {CHECKPOINTS_DIR}/")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
main.py CHANGED
@@ -22,7 +22,10 @@ from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
 from av import VideoFrame
 
 from mediapipe.tasks.python.vision import FaceLandmarksConnections
-from ui.pipeline import FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline
+from ui.pipeline import (
+    FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline,
+    L2CSPipeline, is_l2cs_weights_available,
+)
 from models.face_mesh import FaceMeshDetector
 
 # ================ FACE MESH DRAWING (server-side, for WebRTC) ================
@@ -164,6 +167,7 @@ app.add_middleware(
 db_path = "focus_guard.db"
 pcs = set()
 _cached_model_name = "mlp"  # in-memory cache, updated via /api/settings
+_l2cs_boost_enabled = False  # when True, L2CS runs alongside the base model
 
 async def _wait_for_ice_gathering(pc: RTCPeerConnection):
     if pc.iceGatheringState == "complete":
@@ -243,6 +247,7 @@ class SettingsUpdate(BaseModel):
     notification_threshold: Optional[int] = None
     frame_rate: Optional[int] = None
     model_name: Optional[str] = None
+    l2cs_boost: Optional[bool] = None
 
 class VideoTransformTrack(VideoStreamTrack):
     def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
@@ -270,6 +275,8 @@ class VideoTransformTrack(VideoStreamTrack):
         self.last_inference_time = now
 
         model_name = _cached_model_name
+        if model_name == "l2cs" and pipelines.get("l2cs") is None:
+            _ensure_l2cs()
        if model_name not in pipelines or pipelines.get(model_name) is None:
            model_name = 'mlp'
        active_pipeline = pipelines.get(model_name)
@@ -455,6 +462,7 @@ pipelines = {
     "mlp": None,
     "hybrid": None,
     "xgboost": None,
+    "l2cs": None,
 }
 
 # Thread pool for CPU-bound inference so the event loop stays responsive.
@@ -464,14 +472,81 @@ _inference_executor = concurrent.futures.ThreadPoolExecutor(
 )
 # One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
 # multiple frames are processed in parallel by the thread pool.
-_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
+_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost", "l2cs")}
+
+_l2cs_load_lock = threading.Lock()
+_l2cs_error: str | None = None
+
+
+def _ensure_l2cs():
+    # lazy-load L2CS on first use, double-checked locking
+    global _l2cs_error
+    if pipelines["l2cs"] is not None:
+        return True
+    with _l2cs_load_lock:
+        if pipelines["l2cs"] is not None:
+            return True
+        if not is_l2cs_weights_available():
+            _l2cs_error = "Weights not found"
+            return False
+        try:
+            pipelines["l2cs"] = L2CSPipeline()
+            _l2cs_error = None
+            print("[OK] L2CSPipeline lazy-loaded")
+            return True
+        except Exception as e:
+            _l2cs_error = str(e)
+            print(f"[ERR] L2CS lazy-load failed: {e}")
+            return False
 
 
-def _process_frame_safe(pipeline, frame, model_name: str):
-    """Run process_frame in executor with per-pipeline lock."""
+def _process_frame_safe(pipeline, frame, model_name):
     with _pipeline_locks[model_name]:
         return pipeline.process_frame(frame)
 
+
+_BOOST_BASE_W = 0.35
+_BOOST_L2CS_W = 0.65
+_BOOST_VETO = 0.38  # L2CS below this -> forced not-focused
+
+
+def _process_frame_with_l2cs_boost(base_pipeline, frame, base_model_name):
+    # run base model
+    with _pipeline_locks[base_model_name]:
+        base_out = base_pipeline.process_frame(frame)
+
+    l2cs_pipe = pipelines.get("l2cs")
+    if l2cs_pipe is None:
+        base_out["boost_active"] = False
+        return base_out
+
+    # run L2CS
+    with _pipeline_locks["l2cs"]:
+        l2cs_out = l2cs_pipe.process_frame(frame)
+
+    base_score = base_out.get("mlp_prob", base_out.get("raw_score", 0.0))
+    l2cs_score = l2cs_out.get("raw_score", 0.0)
+
+    # veto: gaze clearly off-screen overrides base model
+    if l2cs_score < _BOOST_VETO:
+        fused_score = l2cs_score * 0.8
+        is_focused = False
+    else:
+        fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score
+        is_focused = fused_score >= 0.52
+
+    base_out["raw_score"] = fused_score
+    base_out["is_focused"] = is_focused
+    base_out["boost_active"] = True
+    base_out["base_score"] = round(base_score, 3)
+    base_out["l2cs_score"] = round(l2cs_score, 3)
+
+    if l2cs_out.get("gaze_yaw") is not None:
+        base_out["gaze_yaw"] = l2cs_out["gaze_yaw"]
+        base_out["gaze_pitch"] = l2cs_out["gaze_pitch"]
+
+    return base_out
+
 @app.on_event("startup")
 async def startup_event():
     global pipelines, _cached_model_name
@@ -509,6 +584,11 @@ async def startup_event():
     except Exception as e:
         print(f"[ERR] Failed to load XGBoostPipeline: {e}")
 
+    if is_l2cs_weights_available():
+        print("[OK] L2CS weights found — pipeline will be lazy-loaded on first use")
+    else:
+        print("[WARN] L2CS weights not found — l2cs model unavailable")
+
 @app.on_event("shutdown")
 async def shutdown_event():
     _inference_executor.shutdown(wait=False)
@@ -579,14 +659,19 @@ async def webrtc_offer(offer: dict):
 
 @app.websocket("/ws/video")
 async def websocket_endpoint(websocket: WebSocket):
+    from models.gaze_calibration import GazeCalibration
+    from models.gaze_eye_fusion import GazeEyeFusion
+
     await websocket.accept()
     session_id = None
     frame_count = 0
     running = True
    event_buffer = _EventBuffer(flush_interval=2.0)
 
+    # Calibration state (per-connection)
+    _cal: dict = {"cal": None, "collecting": False, "fusion": None}
+
     # Latest frame slot — only the most recent frame is kept, older ones are dropped.
-    # Using a dict so nested functions can mutate without nonlocal issues.
     _slot = {"frame": None}
     _frame_ready = asyncio.Event()
 
@@ -617,7 +702,6 @@ async def websocket_endpoint(websocket: WebSocket):
            data = json.loads(text)
 
            if data["type"] == "frame":
-                # Legacy base64 path (fallback)
                _slot["frame"] = base64.b64decode(data["image"])
                _frame_ready.set()
 
@@ -636,6 +720,47 @@ async def websocket_endpoint(websocket: WebSocket):
                if summary:
                    await websocket.send_json({"type": "session_ended", "summary": summary})
                session_id = None
+
+            # ---- Calibration commands ----
+            elif data["type"] == "calibration_start":
+                loop = asyncio.get_event_loop()
+                await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+                _cal["cal"] = GazeCalibration()
+                _cal["collecting"] = True
+                _cal["fusion"] = None
+                cal = _cal["cal"]
+                await websocket.send_json({
+                    "type": "calibration_started",
+                    "num_points": cal.num_points,
+                    "target": list(cal.current_target),
+                    "index": cal.current_index,
+                })
+
+            elif data["type"] == "calibration_next":
+                cal = _cal.get("cal")
+                if cal is not None:
+                    more = cal.advance()
+                    if more:
+                        await websocket.send_json({
+                            "type": "calibration_point",
+                            "target": list(cal.current_target),
+                            "index": cal.current_index,
+                        })
+                    else:
+                        _cal["collecting"] = False
+                        ok = cal.fit()
+                        if ok:
+                            _cal["fusion"] = GazeEyeFusion(cal)
+                            await websocket.send_json({"type": "calibration_done", "success": True})
+                        else:
+                            await websocket.send_json({"type": "calibration_done", "success": False, "error": "Not enough samples"})
+
+            elif data["type"] == "calibration_cancel":
+                _cal["cal"] = None
+                _cal["collecting"] = False
+                _cal["fusion"] = None
+                await websocket.send_json({"type": "calibration_cancelled"})
+
    except WebSocketDisconnect:
        running = False
        _frame_ready.set()
@@ -654,7 +779,6 @@ async def websocket_endpoint(websocket: WebSocket):
            if not running:
                return
 
-            # Grab latest frame and clear slot
            raw = _slot["frame"]
            _slot["frame"] = None
            if raw is None:
@@ -667,38 +791,87 @@ async def websocket_endpoint(websocket: WebSocket):
                continue
            frame = cv2.resize(frame, (640, 480))
 
-            model_name = _cached_model_name
+            # During calibration collection, always use L2CS
+            collecting = _cal.get("collecting", False)
+            if collecting:
+                if pipelines.get("l2cs") is None:
+                    await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+                use_model = "l2cs" if pipelines.get("l2cs") is not None else _cached_model_name
+            else:
+                use_model = _cached_model_name
+
+            model_name = use_model
+            if model_name == "l2cs" and pipelines.get("l2cs") is None:
+                await loop.run_in_executor(_inference_executor, _ensure_l2cs)
            if model_name not in pipelines or pipelines.get(model_name) is None:
                model_name = "mlp"
            active_pipeline = pipelines.get(model_name)
 
+            # L2CS boost: run L2CS alongside base model
+            use_boost = (
+                _l2cs_boost_enabled
+                and model_name != "l2cs"
+                and pipelines.get("l2cs") is not None
+                and not collecting
+            )
+
            landmarks_list = None
+            out = None
            if active_pipeline is not None:
-                out = await loop.run_in_executor(
-                    _inference_executor,
-                    _process_frame_safe,
-                    active_pipeline,
-                    frame,
-                    model_name,
-                )
+                if use_boost:
+                    out = await loop.run_in_executor(
+                        _inference_executor,
+                        _process_frame_with_l2cs_boost,
+                        active_pipeline,
+                        frame,
+                        model_name,
+                    )
+                else:
+                    out = await loop.run_in_executor(
+                        _inference_executor,
+                        _process_frame_safe,
+                        active_pipeline,
+                        frame,
+                        model_name,
+                    )
                is_focused = out["is_focused"]
                confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
 
                lm = out.get("landmarks")
                if lm is not None:
-                    # Send all 478 landmarks as flat array for tessellation drawing
                    landmarks_list = [
                        [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
                        for i in range(lm.shape[0])
                    ]
 
+                # Calibration sample collection (L2CS gaze angles)
+                if collecting and _cal.get("cal") is not None:
+                    pipe_yaw = out.get("gaze_yaw")
+                    pipe_pitch = out.get("gaze_pitch")
+                    if pipe_yaw is not None and pipe_pitch is not None:
+                        _cal["cal"].collect_sample(pipe_yaw, pipe_pitch)
+
+                # Gaze fusion (when L2CS active + calibration fitted)
+                fusion = _cal.get("fusion")
+                if (
+                    fusion is not None
+                    and model_name == "l2cs"
+                    and out.get("gaze_yaw") is not None
+                ):
+                    fuse = fusion.update(
+                        out["gaze_yaw"], out["gaze_pitch"], lm
+                    )
+                    is_focused = fuse["focused"]
+                    confidence = fuse["focus_score"]
+
                if session_id:
-                    event_buffer.add(session_id, is_focused, confidence, {
+                    metadata = {
                        "s_face": out.get("s_face", 0.0),
                        "s_eye": out.get("s_eye", 0.0),
                        "mar": out.get("mar", 0.0),
                        "model": model_name,
-                    })
+                    }
+                    event_buffer.add(session_id, is_focused, confidence, metadata)
            else:
                is_focused = False
                confidence = 0.0
@@ -710,8 +883,7 @@ async def websocket_endpoint(websocket: WebSocket):
                "model": model_name,
                "fc": frame_count,
            }
-            if active_pipeline is not None:
-                # Send detailed metrics for HUD
+            if out is not None:
                if out.get("yaw") is not None:
                    resp["yaw"] = round(out["yaw"], 1)
                    resp["pitch"] = round(out["pitch"], 1)
@@ -720,6 +892,24 @@ async def websocket_endpoint(websocket: WebSocket):
                    resp["mar"] = round(out["mar"], 3)
                resp["sf"] = round(out.get("s_face", 0), 3)
                resp["se"] = round(out.get("s_eye", 0), 3)
+
+                # Gaze fusion fields (L2CS standalone or boost mode)
+                fusion = _cal.get("fusion")
+                has_gaze = out.get("gaze_yaw") is not None
+                if fusion is not None and has_gaze and (model_name == "l2cs" or use_boost):
+                    fuse = fusion.update(out["gaze_yaw"], out["gaze_pitch"], out.get("landmarks"))
+                    resp["gaze_x"] = fuse["gaze_x"]
+                    resp["gaze_y"] = fuse["gaze_y"]
+                    resp["on_screen"] = fuse["on_screen"]
+                    if model_name == "l2cs":
+                        resp["focused"] = fuse["focused"]
+                        resp["confidence"] = round(fuse["focus_score"], 3)
+
+                if out.get("boost_active"):
+                    resp["boost"] = True
+                    resp["base_score"] = out.get("base_score", 0)
+                    resp["l2cs_score"] = out.get("l2cs_score", 0)
+
            if landmarks_list is not None:
                resp["lm"] = landmarks_list
            await websocket.send_json(resp)
@@ -852,8 +1042,9 @@ async def get_settings():
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
-        if row: return dict(row)
-        else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
+        result = dict(row) if row else {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
+        result['l2cs_boost'] = _l2cs_boost_enabled
+        return result
 
 @app.put("/api/settings")
 async def update_settings(settings: SettingsUpdate):
@@ -878,12 +1069,28 @@ async def update_settings(settings: SettingsUpdate):
    if settings.frame_rate is not None:
        updates.append("frame_rate = ?")
        params.append(max(5, min(60, settings.frame_rate)))
-    if settings.model_name is not None and settings.model_name in pipelines and pipelines[settings.model_name] is not None:
+    if settings.model_name is not None and settings.model_name in pipelines:
+        if settings.model_name == "l2cs":
+            loop = asyncio.get_event_loop()
+            loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+            if not loaded:
+                raise HTTPException(status_code=400, detail=f"L2CS model unavailable: {_l2cs_error}")
+        elif pipelines[settings.model_name] is None:
+            raise HTTPException(status_code=400, detail=f"Model '{settings.model_name}' not loaded")
        updates.append("model_name = ?")
        params.append(settings.model_name)
        global _cached_model_name
        _cached_model_name = settings.model_name
 
+    if settings.l2cs_boost is not None:
+        global _l2cs_boost_enabled
+        if settings.l2cs_boost:
+            loop = asyncio.get_event_loop()
+            loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+            if not loaded:
+                raise HTTPException(status_code=400, detail=f"L2CS boost unavailable: {_l2cs_error}")
+        _l2cs_boost_enabled = settings.l2cs_boost
+
    if updates:
        query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
        await db.execute(query, params)
@@ -919,15 +1126,55 @@ async def get_stats_summary():
 
 @app.get("/api/models")
 async def get_available_models():
-    """Return list of loaded model names and which is currently active."""
-    available = [name for name, p in pipelines.items() if p is not None]
+    """Return model names, statuses, and which is currently active."""
+    statuses = {}
+    errors = {}
+    available = []
+    for name, p in pipelines.items():
+        if name == "l2cs":
+            if p is not None:
+                statuses[name] = "ready"
+                available.append(name)
+            elif is_l2cs_weights_available():
+                statuses[name] = "lazy"
+                available.append(name)
+            elif _l2cs_error:
+                statuses[name] = "error"
+                errors[name] = _l2cs_error
+            else:
+                statuses[name] = "unavailable"
+        elif p is not None:
+            statuses[name] = "ready"
+            available.append(name)
+        else:
+            statuses[name] = "unavailable"
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
        current = row[0] if row else "mlp"
    if current not in available and available:
        current = available[0]
-    return {"available": available, "current": current}
+    l2cs_boost_available = (
+        statuses.get("l2cs") in ("ready", "lazy") and current != "l2cs"
+    )
+    return {
+        "available": available,
+        "current": current,
+        "statuses": statuses,
+        "errors": errors,
+        "l2cs_boost": _l2cs_boost_enabled,
+        "l2cs_boost_available": l2cs_boost_available,
+    }
+
+@app.get("/api/l2cs/status")
+async def l2cs_status():
+    """L2CS-specific status: weights available, loaded, and calibration info."""
+    loaded = pipelines.get("l2cs") is not None
+    return {
+        "weights_available": is_l2cs_weights_available(),
+        "loaded": loaded,
+        "error": _l2cs_error,
+    }
 
 @app.get("/api/mesh-topology")
 async def get_mesh_topology():
models/L2CS-Net/.gitignore ADDED
@@ -0,0 +1,140 @@
+# Ignore the test data - sensitive
+datasets/
+evaluation/
+output/
+
+# Ignore debugging configurations
+/.vscode
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Ignore other files
+my.secrets
models/L2CS-Net/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Ahmed Abdelrahman
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
models/L2CS-Net/README.md ADDED
@@ -0,0 +1,148 @@
+
+ <p align="center">
+ <img src="https://github.com/Ahmednull/Storage/blob/main/gaze.gif" alt="animated" />
+ </p>
+
+ ___
+
+ # L2CS-Net
+
+ The official PyTorch implementation of L2CS-Net for gaze estimation and tracking.
+
+ ## Installation
+ <img src="https://img.shields.io/badge/python%20-%2314354C.svg?&style=for-the-badge&logo=python&logoColor=white"/> <img src="https://img.shields.io/badge/PyTorch%20-%23EE4C2C.svg?&style=for-the-badge&logo=PyTorch&logoColor=white" />
+
+ Install the package with:
+
+ ```
+ pip install git+https://github.com/Ahmednull/L2CS-Net.git@main
+ ```
+
+ Or clone the repository and install it (optionally in editable mode) with:
+
+ ```
+ pip install [-e] .
+ ```
+
+ You should now be able to import the package:
+
+ ```
+ $ python
+ >>> import l2cs
+ ```
+
+ ## Usage
+
+ Detect faces and predict gaze from a webcam:
+
+ ```python
+ import pathlib
+
+ import cv2
+ import torch
+
+ from l2cs import Pipeline, render
+
+ CWD = pathlib.Path.cwd()
+
+ gaze_pipeline = Pipeline(
+     weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
+     arch='ResNet50',
+     device=torch.device('cpu')  # or torch.device('cuda')
+ )
+
+ cap = cv2.VideoCapture(0)
+ _, frame = cap.read()
+
+ # Process frame and visualize
+ results = gaze_pipeline.step(frame)
+ frame = render(frame, results)
+ ```
+
+ ## Demo
+ * Download the pre-trained models from [here](https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd?usp=sharing) and store them in *models/*.
+ * Run:
+ ```
+ python demo.py \
+  --snapshot models/L2CSNet_gaze360.pkl \
+  --gpu 0 \
+  --cam 0
+ ```
+ The demo will run with the *L2CSNet_gaze360.pkl* pretrained model.
+
+ ## Community Contributions
+
+ - [Gaze Detection and Eye Tracking: A How-To Guide](https://blog.roboflow.com/gaze-direction-position/): Use L2CS-Net through an HTTP interface with the open source Roboflow Inference project.
+
+ ## MPIIGaze
+ We provide code for training and testing on the MPIIGaze dataset with leave-one-person-out evaluation.
+
+ ### Prepare datasets
+ * Download the **MPIIFaceGaze dataset** from [here](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/its-written-all-over-your-face-full-face-appearance-based-gaze-estimation).
+ * Apply the data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).
+ * Store the dataset in *datasets/MPIIFaceGaze*.
+
+ ### Train
+ ```
+ python train.py \
+  --dataset mpiigaze \
+  --snapshot output/snapshots \
+  --gpu 0 \
+  --num_epochs 50 \
+  --batch_size 16 \
+  --lr 0.00001 \
+  --alpha 1
+ ```
+ The code performs leave-one-person-out training automatically and stores the models in *output/snapshots*.
+
+ ### Test
+ ```
+ python test.py \
+  --dataset mpiigaze \
+  --snapshot output/snapshots/snapshot_folder \
+  --evalpath evaluation/L2CS-mpiigaze \
+  --gpu 0
+ ```
+ The code performs leave-one-person-out testing automatically and stores the results in *evaluation/L2CS-mpiigaze*.
+
+ To get the average leave-one-person-out accuracy, use:
+ ```
+ python leave_one_out_eval.py \
+  --evalpath evaluation/L2CS-mpiigaze \
+  --respath evaluation/L2CS-mpiigaze
+ ```
+ The script reads the evaluation path and writes the leave-one-out gaze accuracy to *evaluation/L2CS-mpiigaze*.
+
+ ## Gaze360
+ We provide code for training and testing on the Gaze360 dataset with a train-val-test split.
+
+ ### Prepare datasets
+ * Download the **Gaze360 dataset** from [here](http://gaze360.csail.mit.edu/download.php).
+ * Apply the data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).
+ * Store the dataset in *datasets/Gaze360*.
+
+ ### Train
+ ```
+ python train.py \
+  --dataset gaze360 \
+  --snapshot output/snapshots \
+  --gpu 0 \
+  --num_epochs 50 \
+  --batch_size 16 \
+  --lr 0.00001 \
+  --alpha 1
+ ```
+ The code performs training and stores the models in *output/snapshots*.
+
+ ### Test
+ ```
+ python test.py \
+  --dataset gaze360 \
+  --snapshot output/snapshots/snapshot_folder \
+  --evalpath evaluation/L2CS-gaze360 \
+  --gpu 0
+ ```
+ The code performs testing on snapshot_folder and stores the results in *evaluation/L2CS-gaze360*.
models/L2CS-Net/demo.py ADDED
@@ -0,0 +1,87 @@
+ import argparse
+ import pathlib
+ import numpy as np
+ import cv2
+ import time
+
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from torchvision import transforms
+ import torch.backends.cudnn as cudnn
+ import torchvision
+
+ from PIL import Image, ImageOps
+
+ from face_detection import RetinaFace
+
+ from l2cs import select_device, draw_gaze, getArch, Pipeline, render
+
+ CWD = pathlib.Path.cwd()
+
+ def parse_args():
+     """Parse input arguments."""
+     parser = argparse.ArgumentParser(
+         description='Gaze evaluation using a model pretrained with L2CS-Net on Gaze360.')
+     parser.add_argument(
+         '--device', dest='device', help='Device to run model: cpu or gpu:0',
+         default="cpu", type=str)
+     parser.add_argument(
+         '--snapshot', dest='snapshot', help='Path of model snapshot.',
+         default='output/snapshots/L2CS-gaze360-_loader-180-4/_epoch_55.pkl', type=str)
+     parser.add_argument(
+         '--cam', dest='cam_id', help='Camera device id to use [0]',
+         default=0, type=int)
+     parser.add_argument(
+         '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152',
+         default='ResNet50', type=str)
+
+     args = parser.parse_args()
+     return args
+
+ if __name__ == '__main__':
+     args = parse_args()
+
+     cudnn.enabled = True
+     arch = args.arch
+     cam = args.cam_id
+
+     gaze_pipeline = Pipeline(
+         weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
+         arch='ResNet50',
+         device=select_device(args.device, batch_size=1)
+     )
+
+     cap = cv2.VideoCapture(cam)
+
+     # Check if the webcam is opened correctly
+     if not cap.isOpened():
+         raise IOError("Cannot open webcam")
+
+     with torch.no_grad():
+         while True:
+
+             # Get frame
+             success, frame = cap.read()
+             start_fps = time.time()
+
+             if not success:
+                 print("Failed to obtain frame")
+                 time.sleep(0.1)
+                 continue
+
+             # Process frame
+             results = gaze_pipeline.step(frame)
+
+             # Visualize output
+             frame = render(frame, results)
+
+             myFPS = 1.0 / (time.time() - start_fps)
+             cv2.putText(frame, 'FPS: {:.1f}'.format(myFPS), (10, 20),
+                         cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 255, 0), 1, cv2.LINE_AA)
+
+             cv2.imshow("Demo", frame)
+             if cv2.waitKey(1) & 0xFF == ord('q'):
+                 break
models/L2CS-Net/l2cs/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from .utils import select_device, natural_keys, gazeto3d, angular, getArch
+ from .vis import draw_gaze, render
+ from .model import L2CS
+ from .pipeline import Pipeline
+ from .datasets import Gaze360, Mpiigaze
+
+ __all__ = [
+     # Classes
+     'L2CS',
+     'Pipeline',
+     'Gaze360',
+     'Mpiigaze',
+     # Utils
+     'render',
+     'select_device',
+     'draw_gaze',
+     'natural_keys',
+     'gazeto3d',
+     'angular',
+     'getArch'
+ ]
models/L2CS-Net/l2cs/datasets.py ADDED
@@ -0,0 +1,157 @@
+ import os
+ import numpy as np
+ import cv2
+
+ import torch
+ from torch.utils.data.dataset import Dataset
+ from torchvision import transforms
+ from PIL import Image, ImageFilter
+
+
+ class Gaze360(Dataset):
+     def __init__(self, path, root, transform, angle, binwidth, train=True):
+         self.transform = transform
+         self.root = root
+         self.orig_list_len = 0
+         self.angle = angle
+         if train == False:
+             angle = 90
+         self.binwidth = binwidth
+         self.lines = []
+         if isinstance(path, list):
+             for i in path:
+                 with open(i) as f:
+                     line = f.readlines()
+                     line.pop(0)
+                     self.lines.extend(line)
+         else:
+             with open(path) as f:
+                 lines = f.readlines()
+                 lines.pop(0)
+                 self.orig_list_len = len(lines)
+                 for line in lines:
+                     gaze2d = line.strip().split(" ")[5]
+                     label = np.array(gaze2d.split(",")).astype("float")
+                     if abs((label[0] * 180 / np.pi)) <= angle and abs((label[1] * 180 / np.pi)) <= angle:
+                         self.lines.append(line)
+
+         print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len - len(self.lines), angle))
+
+     def __len__(self):
+         return len(self.lines)
+
+     def __getitem__(self, idx):
+         line = self.lines[idx]
+         line = line.strip().split(" ")
+
+         face = line[0]
+         lefteye = line[1]
+         righteye = line[2]
+         name = line[3]
+         gaze2d = line[5]
+         label = np.array(gaze2d.split(",")).astype("float")
+         label = torch.from_numpy(label).type(torch.FloatTensor)
+
+         pitch = label[0] * 180 / np.pi
+         yaw = label[1] * 180 / np.pi
+
+         img = Image.open(os.path.join(self.root, face))
+
+         if self.transform:
+             img = self.transform(img)
+
+         # Bin values
+         bins = np.array(range(-1 * self.angle, self.angle, self.binwidth))
+         binned_pose = np.digitize([pitch, yaw], bins) - 1
+
+         labels = binned_pose
+         cont_labels = torch.FloatTensor([pitch, yaw])
+
+         return img, labels, cont_labels, name
+
+
+ class Mpiigaze(Dataset):
+     def __init__(self, pathorg, root, transform, train, angle, fold=0):
+         self.transform = transform
+         self.root = root
+         self.orig_list_len = 0
+         self.lines = []
+         path = pathorg.copy()
+         if train == True:
+             path.pop(fold)
+         else:
+             path = path[fold]
+         if isinstance(path, list):
+             for i in path:
+                 with open(i) as f:
+                     lines = f.readlines()
+                     lines.pop(0)
+                     self.orig_list_len += len(lines)
+                     for line in lines:
+                         gaze2d = line.strip().split(" ")[7]
+                         label = np.array(gaze2d.split(",")).astype("float")
+                         if abs((label[0] * 180 / np.pi)) <= angle and abs((label[1] * 180 / np.pi)) <= angle:
+                             self.lines.append(line)
+         else:
+             with open(path) as f:
+                 lines = f.readlines()
+                 lines.pop(0)
+                 self.orig_list_len += len(lines)
+                 for line in lines:
+                     gaze2d = line.strip().split(" ")[7]
+                     label = np.array(gaze2d.split(",")).astype("float")
+                     if abs((label[0] * 180 / np.pi)) <= 42 and abs((label[1] * 180 / np.pi)) <= 42:
+                         self.lines.append(line)
+
+         print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len - len(self.lines), angle))
+
+     def __len__(self):
+         return len(self.lines)
+
+     def __getitem__(self, idx):
+         line = self.lines[idx]
+         line = line.strip().split(" ")
+
+         name = line[3]
+         gaze2d = line[7]
+         head2d = line[8]
+         lefteye = line[1]
+         righteye = line[2]
+         face = line[0]
+
+         label = np.array(gaze2d.split(",")).astype("float")
+         label = torch.from_numpy(label).type(torch.FloatTensor)
+
+         pitch = label[0] * 180 / np.pi
+         yaw = label[1] * 180 / np.pi
+
+         img = Image.open(os.path.join(self.root, face))
+
+         if self.transform:
+             img = self.transform(img)
+
+         # Bin values
+         bins = np.array(range(-42, 42, 3))
+         binned_pose = np.digitize([pitch, yaw], bins) - 1
+
+         labels = binned_pose
+         cont_labels = torch.FloatTensor([pitch, yaw])
+
+         return img, labels, cont_labels, name
models/L2CS-Net/l2cs/model.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ import math
+ import torch.nn.functional as F
+
+
+ class L2CS(nn.Module):
+     def __init__(self, block, layers, num_bins):
+         self.inplanes = 64
+         super(L2CS, self).__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+         self.fc_yaw_gaze = nn.Linear(512 * block.expansion, num_bins)
+         self.fc_pitch_gaze = nn.Linear(512 * block.expansion, num_bins)
+
+         # Vestigial layer from previous experiments
+         self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = self.avgpool(x)
+         x = x.view(x.size(0), -1)
+
+         # gaze
+         pre_yaw_gaze = self.fc_yaw_gaze(x)
+         pre_pitch_gaze = self.fc_pitch_gaze(x)
+         return pre_yaw_gaze, pre_pitch_gaze
models/L2CS-Net/l2cs/pipeline.py ADDED
@@ -0,0 +1,133 @@
+ import pathlib
+ from typing import Union
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from face_detection import RetinaFace
+
+ from .utils import prep_input_numpy, getArch
+ from .results import GazeResultContainer
+
+
+ class Pipeline:
+
+     def __init__(
+         self,
+         weights: pathlib.Path,
+         arch: str,
+         device: torch.device = torch.device('cpu'),
+         include_detector: bool = True,
+         confidence_threshold: float = 0.5
+     ):
+
+         # Save input parameters
+         self.weights = weights
+         self.include_detector = include_detector
+         self.device = device
+         self.confidence_threshold = confidence_threshold
+
+         # Create L2CS model
+         self.model = getArch(arch, 90)
+         self.model.load_state_dict(torch.load(self.weights, map_location=device))
+         self.model.to(self.device)
+         self.model.eval()
+
+         # Create RetinaFace if requested
+         if self.include_detector:
+             if device.type == 'cpu':
+                 self.detector = RetinaFace()
+             else:
+                 self.detector = RetinaFace(gpu_id=device.index)
+
+             self.softmax = nn.Softmax(dim=1)
+             self.idx_tensor = [idx for idx in range(90)]
+             self.idx_tensor = torch.FloatTensor(self.idx_tensor).to(self.device)
+
+     def step(self, frame: np.ndarray) -> GazeResultContainer:
+
+         # Creating containers
+         face_imgs = []
+         bboxes = []
+         landmarks = []
+         scores = []
+
+         if self.include_detector:
+             faces = self.detector(frame)
+
+             if faces is not None:
+                 for box, landmark, score in faces:
+
+                     # Apply threshold
+                     if score < self.confidence_threshold:
+                         continue
+
+                     # Extract safe min and max of x,y
+                     x_min = int(box[0])
+                     if x_min < 0:
+                         x_min = 0
+                     y_min = int(box[1])
+                     if y_min < 0:
+                         y_min = 0
+                     x_max = int(box[2])
+                     y_max = int(box[3])
+
+                     # Crop image
+                     img = frame[y_min:y_max, x_min:x_max]
+                     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                     img = cv2.resize(img, (224, 224))
+                     face_imgs.append(img)
+
+                     # Save data
+                     bboxes.append(box)
+                     landmarks.append(landmark)
+                     scores.append(score)
+
+                 # Predict gaze
+                 pitch, yaw = self.predict_gaze(np.stack(face_imgs))
+
+             else:
+                 pitch = np.empty((0, 1))
+                 yaw = np.empty((0, 1))
+
+         else:
+             pitch, yaw = self.predict_gaze(frame)
+
+         # Save data
+         results = GazeResultContainer(
+             pitch=pitch,
+             yaw=yaw,
+             bboxes=np.stack(bboxes),
+             landmarks=np.stack(landmarks),
+             scores=np.stack(scores)
+         )
+
+         return results
+
+     def predict_gaze(self, frame: Union[np.ndarray, torch.Tensor]):
+
+         # Prepare input
+         if isinstance(frame, np.ndarray):
+             img = prep_input_numpy(frame, self.device)
+         elif isinstance(frame, torch.Tensor):
+             img = frame
+         else:
+             raise RuntimeError("Invalid dtype for input")
+
+         # Predict
+         gaze_pitch, gaze_yaw = self.model(img)
+         pitch_predicted = self.softmax(gaze_pitch)
+         yaw_predicted = self.softmax(gaze_yaw)
+
+         # Get continuous predictions in degrees.
+         pitch_predicted = torch.sum(pitch_predicted.data * self.idx_tensor, dim=1) * 4 - 180
+         yaw_predicted = torch.sum(yaw_predicted.data * self.idx_tensor, dim=1) * 4 - 180
+
+         pitch_predicted = pitch_predicted.cpu().detach().numpy() * np.pi / 180.0
+         yaw_predicted = yaw_predicted.cpu().detach().numpy() * np.pi / 180.0
+
+         return pitch_predicted, yaw_predicted
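For reference, the bin-to-angle decoding in `predict_gaze` above (softmax over 90 bins of 4° each, then an expectation `sum(p_i * i) * 4 - 180`) can be sketched in plain NumPy. The helper name `decode_bins` and the one-hot test input below are illustrative, not part of the library:

```python
import numpy as np

def decode_bins(logits, num_bins=90, bin_width=4.0, offset=-180.0):
    """Expectation over softmax bin probabilities -> continuous angle in degrees.

    Mirrors the decoding step above: softmax, then sum(p_i * i) * 4 - 180.
    """
    e = np.exp(logits - logits.max())  # numerically stable softmax
    probs = e / e.sum()
    return float(np.dot(probs, np.arange(num_bins)) * bin_width + offset)

# A (near) one-hot logit at bin 45 decodes to 45 * 4 - 180 = 0 degrees (straight ahead).
logits = np.full(90, -1e9)
logits[45] = 0.0
print(decode_bins(logits))  # → 0.0
```

Because the decoding is an expectation rather than an argmax, a distribution spread over neighboring bins yields sub-bin (sub-4°) resolution.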
models/L2CS-Net/l2cs/results.py ADDED
@@ -0,0 +1,11 @@
+ from dataclasses import dataclass
+ import numpy as np
+
+
+ @dataclass
+ class GazeResultContainer:
+
+     pitch: np.ndarray
+     yaw: np.ndarray
+     bboxes: np.ndarray
+     landmarks: np.ndarray
+     scores: np.ndarray
models/L2CS-Net/l2cs/utils.py ADDED
@@ -0,0 +1,145 @@
+ import sys
+ import os
+ import math
+ from math import cos, sin
+ from pathlib import Path
+ import subprocess
+ import re
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import scipy.io as sio
+ import cv2
+ import torchvision
+ from torchvision import transforms
+
+ from .model import L2CS
+
+ transformations = transforms.Compose([
+     transforms.ToPILImage(),
+     transforms.Resize(448),
+     transforms.ToTensor(),
+     transforms.Normalize(
+         mean=[0.485, 0.456, 0.406],
+         std=[0.229, 0.224, 0.225]
+     )
+ ])
+
+ def atoi(text):
+     return int(text) if text.isdigit() else text
+
+ def natural_keys(text):
+     '''
+     alist.sort(key=natural_keys) sorts in human order
+     http://nedbatchelder.com/blog/200712/human_sorting.html
+     (See Toothy's implementation in the comments)
+     '''
+     return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+ def prep_input_numpy(img: np.ndarray, device: str):
+     """Prepare a numpy array as input to L2CS-Net."""
+
+     if len(img.shape) == 4:
+         imgs = []
+         for im in img:
+             imgs.append(transformations(im))
+         img = torch.stack(imgs)
+     else:
+         img = transformations(img)
+
+     img = img.to(device)
+
+     if len(img.shape) == 3:
+         img = img.unsqueeze(0)
+
+     return img
+
+ def gazeto3d(gaze):
+     gaze_gt = np.zeros([3])
+     gaze_gt[0] = -np.cos(gaze[1]) * np.sin(gaze[0])
+     gaze_gt[1] = -np.sin(gaze[1])
+     gaze_gt[2] = -np.cos(gaze[1]) * np.cos(gaze[0])
+     return gaze_gt
+
+ def angular(gaze, label):
+     total = np.sum(gaze * label)
+     return np.arccos(min(total / (np.linalg.norm(gaze) * np.linalg.norm(label)), 0.9999999)) * 180 / np.pi
+
+ def select_device(device='', batch_size=None):
+     # device = 'cpu' or '0' or '0,1,2,3'
+     s = f'YOLOv3 🚀 {git_describe()} torch {torch.__version__} '  # string
+     cpu = device.lower() == 'cpu'
+     if cpu:
+         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
+     elif device:  # non-cpu device requested
+         os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
+
+     cuda = not cpu and torch.cuda.is_available()
+     if cuda:
+         devices = device.split(',') if device else range(torch.cuda.device_count())  # i.e. 0,1,6,7
+         n = len(devices)  # device count
+         if n > 1 and batch_size:  # check batch_size is divisible by device_count
+             assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
+         space = ' ' * len(s)
+         for i, d in enumerate(devices):
+             p = torch.cuda.get_device_properties(i)
+             s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
+     else:
+         s += 'CPU\n'
+
+     return torch.device('cuda:0' if cuda else 'cpu')
+
+ def spherical2cartesial(x):
+     output = torch.zeros(x.size(0), 3)
+     output[:, 2] = -torch.cos(x[:, 1]) * torch.cos(x[:, 0])
+     output[:, 0] = torch.cos(x[:, 1]) * torch.sin(x[:, 0])
+     output[:, 1] = torch.sin(x[:, 1])
+     return output
+
+ def compute_angular_error(input, target):
+     input = spherical2cartesial(input)
+     target = spherical2cartesial(target)
+
+     input = input.view(-1, 3, 1)
+     target = target.view(-1, 1, 3)
+     output_dot = torch.bmm(target, input)
+     output_dot = output_dot.view(-1)
+     output_dot = torch.acos(output_dot)
+     output_dot = output_dot.data
+     output_dot = 180 * torch.mean(output_dot) / math.pi
+     return output_dot
+
+ def softmax_temperature(tensor, temperature):
+     result = torch.exp(tensor / temperature)
+     result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result))
+     return result
+
+ def git_describe(path=Path(__file__).parent):  # path must be a directory
+     # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
+     s = f'git -C {path} describe --tags --long --always'
+     try:
+         return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
+     except subprocess.CalledProcessError:
+         return ''  # not a git repository
+
+ def getArch(arch, bins):
+     # Base network structure
+     if arch == 'ResNet18':
+         model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
+     elif arch == 'ResNet34':
+         model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
+     elif arch == 'ResNet101':
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
+     elif arch == 'ResNet152':
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], bins)
+     else:
+         if arch != 'ResNet50':
+             print('Invalid value for architecture is passed! '
+                   'The default value of ResNet50 will be used instead!')
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
+     return model
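The `gazeto3d`/`angular` pair above defines the evaluation metric: convert a (pitch, yaw) pair to a unit 3D gaze vector, then take the arc-cosine of the dot product of two such vectors. A minimal standalone sketch of the same formulas (NumPy only; the example angles are illustrative):

```python
import numpy as np

def gazeto3d(gaze):
    # (pitch, yaw) in radians -> 3D gaze vector, same convention as l2cs.utils
    return np.array([
        -np.cos(gaze[1]) * np.sin(gaze[0]),
        -np.sin(gaze[1]),
        -np.cos(gaze[1]) * np.cos(gaze[0]),
    ])

def angular(gaze, label):
    # angle between two gaze vectors, in degrees; cosine clamped as upstream does
    cos_sim = np.sum(gaze * label) / (np.linalg.norm(gaze) * np.linalg.norm(label))
    return np.degrees(np.arccos(min(cos_sim, 0.9999999)))

a = gazeto3d((0.0, 0.0))                 # looking straight ahead
b = gazeto3d((np.radians(10.0), 0.0))    # 10 degrees off in the first axis
print(round(angular(a, b), 3))  # → 10.0
```

Note the clamp to 0.9999999 means two identical vectors report a small non-zero error (≈0.026°) instead of exactly 0, a deliberate guard against `arccos` domain errors from floating-point drift.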
models/L2CS-Net/l2cs/vis.py ADDED
@@ -0,0 +1,64 @@
+ import cv2
+ import numpy as np
+ from .results import GazeResultContainer
+
+ def draw_gaze(a, b, c, d, image_in, pitchyaw, thickness=2, color=(255, 255, 0), scale=2.0):
+     """Draw gaze angle on given image with a given eye positions."""
+     image_out = image_in
+     (h, w) = image_in.shape[:2]
+     length = c
+     pos = (int(a + c / 2.0), int(b + d / 2.0))
+     if len(image_out.shape) == 2 or image_out.shape[2] == 1:
+         image_out = cv2.cvtColor(image_out, cv2.COLOR_GRAY2BGR)
+     dx = -length * np.sin(pitchyaw[0]) * np.cos(pitchyaw[1])
+     dy = -length * np.sin(pitchyaw[1])
+     cv2.arrowedLine(image_out, tuple(np.round(pos).astype(np.int32)),
+                     tuple(np.round([pos[0] + dx, pos[1] + dy]).astype(int)), color,
+                     thickness, cv2.LINE_AA, tipLength=0.18)
+     return image_out
+
+ def draw_bbox(frame: np.ndarray, bbox: np.ndarray):
+     x_min = int(bbox[0])
+     if x_min < 0:
+         x_min = 0
+     y_min = int(bbox[1])
+     if y_min < 0:
+         y_min = 0
+     x_max = int(bbox[2])
+     y_max = int(bbox[3])
+
+     cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 1)
+
+     return frame
+
+ def render(frame: np.ndarray, results: GazeResultContainer):
+     # Draw bounding boxes
+     for bbox in results.bboxes:
+         frame = draw_bbox(frame, bbox)
+
+     # Draw gaze
+     for i in range(results.pitch.shape[0]):
+         bbox = results.bboxes[i]
+         pitch = results.pitch[i]
+         yaw = results.yaw[i]
+
+         # Extract safe min and max of x,y
+         x_min = int(bbox[0])
+         if x_min < 0:
+             x_min = 0
+         y_min = int(bbox[1])
+         if y_min < 0:
+             y_min = 0
+         x_max = int(bbox[2])
+         y_max = int(bbox[3])
+
+         # Compute sizes
+         bbox_width = x_max - x_min
+         bbox_height = y_max - y_min
+
+         draw_gaze(x_min, y_min, bbox_width, bbox_height, frame, (pitch, yaw), color=(0, 0, 255))
+
+     return frame
models/L2CS-Net/leave_one_out_eval.py ADDED
@@ -0,0 +1,54 @@
+ import os
+ import argparse
+
+
+ def parse_args():
+     """Parse input arguments."""
+     parser = argparse.ArgumentParser(
+         description='gaze estimation using binned loss function.')
+     parser.add_argument(
+         '--evalpath', dest='evalpath', help='path for evaluating gaze test.',
+         default="evaluation/L2CS-gaze360-_standard-10", type=str)
+     parser.add_argument(
+         '--respath', dest='respath', help='path for saving result.',
+         default="evaluation/L2CS-gaze360-_standard-10", type=str)
+     return parser.parse_args()
+
+ if __name__ == '__main__':
+     args = parse_args()
+     evalpath = args.evalpath
+     respath = args.respath
+     if not os.path.exists(respath):
+         os.makedirs(respath)
+
+     dirlist = os.listdir(evalpath)
+     dirlist.sort()
+
+     with open(os.path.join(respath, "avg.log"), 'w') as outfile:
+         outfile.write("Average equal\n")
+
+         min_err = 10.0
+         best_epoch = 0
+         for j in range(50):
+             j = 20  # note: upstream pins the epoch index here
+             avg = 0.0
+             h = j + 3
+             for i in dirlist:
+                 with open(evalpath + "/" + i + "/mpiigaze_binned.log") as myfile:
+                     x = list(myfile)[h]
+                     # each line ends with "... MAE: <value>"
+                     avg += float(x.split("MAE:", 1)[1])
+
+             avg = avg / 15.0
+             if avg < min_err:
+                 min_err = avg
+                 best_epoch = j + 1
+             outfile.write("epoch" + str(j + 1) + "= " + str(avg) + "\n")
+
+         outfile.write("min angular error equal= " + str(min_err) + " at epoch= " + str(best_epoch) + "\n")
+     print(min_err)
models/L2CS-Net/models/L2CSNet_gaze360.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8a7f3480d868dd48261e1d59f915b0ef0bb33ea12ea00938fb2168f212080665
+ size 95849977
models/L2CS-Net/models/README.md ADDED
@@ -0,0 +1 @@
+ # Path to pre-trained models
models/L2CS-Net/pyproject.toml ADDED
@@ -0,0 +1,44 @@
+ [project]
2
+ name = "l2cs"
3
+ version = "0.0.1"
4
+ description = "The official PyTorch implementation of L2CS-Net for gaze estimation and tracking"
5
+ authors = [
6
+ {name = "Ahmed Abderlrahman"},
7
+ {name = "Thorsten Hempel"}
8
+ ]
9
+ license = {file = "LICENSE.txt"}
10
+ readme = "README.md"
11
+ requires-python = ">3.6"
12
+
13
+ keywords = ["gaze", "estimation", "eye-tracking", "deep-learning", "pytorch"]
14
+
15
+ classifiers = [
16
+ "Programming Language :: Python :: 3"
17
+ ]
18
+
19
+ dependencies = [
20
+ 'matplotlib>=3.3.4',
21
+ 'numpy>=1.19.5',
22
+ 'opencv-python>=4.5.5',
23
+ 'pandas>=1.1.5',
24
+ 'Pillow>=8.4.0',
25
+ 'scipy>=1.5.4',
26
+ 'torch>=1.10.1',
27
+ 'torchvision>=0.11.2',
28
+ 'face_detection@git+https://github.com/elliottzheng/face-detection'
29
+ ]
30
+
31
+ [project.urls]
32
+ homepath = "https://github.com/Ahmednull/L2CS-Net"
33
+ repository = "https://github.com/Ahmednull/L2CS-Net"
34
+
35
+ [build-system]
36
+ requires = ["setuptools", "wheel"]
37
+ build-backend = "setuptools.build_meta"
38
+
39
+ # https://setuptools.pypa.io/en/stable/userguide/datafiles.html
40
+ [tool.setuptools]
41
+ include-package-data = true
42
+
43
+ [tool.setuptools.packages.find]
44
+ where = ["."]
models/L2CS-Net/test.py ADDED
@@ -0,0 +1,284 @@
+ import os, argparse
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+ import torch.backends.cudnn as cudnn
+ import torchvision
+
+ from l2cs import select_device, natural_keys, gazeto3d, angular, L2CS, Gaze360, Mpiigaze
+
+
+ def parse_args():
+     """Parse input arguments."""
+     parser = argparse.ArgumentParser(
+         description='Gaze estimation using L2CSNet.')
+     # Gaze360
+     parser.add_argument(
+         '--gaze360image_dir', dest='gaze360image_dir', help='Directory path for gaze images.',
+         default='datasets/Gaze360/Image', type=str)
+     parser.add_argument(
+         '--gaze360label_dir', dest='gaze360label_dir', help='Directory path for gaze labels.',
+         default='datasets/Gaze360/Label/test.label', type=str)
+     # MPIIGaze
+     parser.add_argument(
+         '--gazeMpiimage_dir', dest='gazeMpiimage_dir', help='Directory path for gaze images.',
+         default='datasets/MPIIFaceGaze/Image', type=str)
+     parser.add_argument(
+         '--gazeMpiilabel_dir', dest='gazeMpiilabel_dir', help='Directory path for gaze labels.',
+         default='datasets/MPIIFaceGaze/Label', type=str)
+     # Important args -------------------------------------------------------------------
+     parser.add_argument(
+         '--dataset', dest='dataset', help='gaze360, mpiigaze',
+         default="gaze360", type=str)
+     parser.add_argument(
+         '--snapshot', dest='snapshot', help='Path to the folder that contains the models.',
+         default='output/snapshots/L2CS-gaze360-_loader-180-4-lr', type=str)
+     parser.add_argument(
+         '--evalpath', dest='evalpath', help='Output path for the gaze test evaluation.',
+         default="evaluation/L2CS-gaze360-_loader-180-4-lr", type=str)
+     parser.add_argument(
+         '--gpu', dest='gpu_id', help='GPU device id to use [0]',
+         default="0", type=str)
+     parser.add_argument(
+         '--batch_size', dest='batch_size', help='Batch size.',
+         default=100, type=int)
+     parser.add_argument(
+         '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, [ResNet50], '
+                                     'ResNet101, ResNet152, Squeezenet_1_0, Squeezenet_1_1, MobileNetV2',
+         default='ResNet50', type=str)
+     # These three are read in the main block but were never registered upstream;
+     # the defaults follow the MPIIGaze protocol used below (28 bins of 3 degrees over -42..42).
+     parser.add_argument(
+         '--bins', dest='bins', help='Number of angle bins.',
+         default=28, type=int)
+     parser.add_argument(
+         '--angle', dest='angle', help='Angle range in degrees.',
+         default=42, type=int)
+     parser.add_argument(
+         '--bin_width', dest='bin_width', help='Width of each angle bin in degrees.',
+         default=3, type=int)
+     # Important args -------------------------------------------------------------------
+     args = parser.parse_args()
+     return args
+
+
+ def getArch(arch, bins):
+     # Base network structure
+     if arch == 'ResNet18':
+         model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
+     elif arch == 'ResNet34':
+         model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
+     elif arch == 'ResNet101':
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
+     elif arch == 'ResNet152':
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], bins)
+     else:
+         if arch != 'ResNet50':
+             print('Invalid value for architecture is passed! '
+                   'The default value of ResNet50 will be used instead!')
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
+     return model
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+     cudnn.enabled = True
+     gpu = select_device(args.gpu_id, batch_size=args.batch_size)
+     batch_size = args.batch_size
+     arch = args.arch
+     data_set = args.dataset
+     evalpath = args.evalpath
+     snapshot_path = args.snapshot
+     bins = args.bins
+     angle = args.angle
+     bin_width = args.bin_width
+
+     transformations = transforms.Compose([
+         transforms.Resize(448),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225]
+         )
+     ])
+
+     if data_set == "gaze360":
+         gaze_dataset = Gaze360(args.gaze360label_dir, args.gaze360image_dir, transformations, 180, 4, train=False)
+         test_loader = torch.utils.data.DataLoader(
+             dataset=gaze_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=4,
+             pin_memory=True)
+
+         if not os.path.exists(evalpath):
+             os.makedirs(evalpath)
+
+         # list all epochs for testing
+         folder = os.listdir(snapshot_path)
+         folder.sort(key=natural_keys)
+         softmax = nn.Softmax(dim=1)
+         with open(os.path.join(evalpath, data_set + ".log"), 'w') as outfile:
+             configuration = f"\ntest configuration = gpu_id={gpu}, batch_size={batch_size}, model_arch={arch}\nStart testing dataset={data_set}----------------------------------------\n"
+             print(configuration)
+             outfile.write(configuration)
+             epoch_list = []
+             avg_MAE = []
+             for epochs in folder:
+                 # Base network structure
+                 model = getArch(arch, 90)
+                 saved_state_dict = torch.load(os.path.join(snapshot_path, epochs))
+                 model.load_state_dict(saved_state_dict)
+                 model.cuda(gpu)
+                 model.eval()
+                 total = 0
+                 idx_tensor = torch.FloatTensor(list(range(90))).cuda(gpu)
+                 avg_error = .0
+
+                 with torch.no_grad():
+                     for j, (images, labels, cont_labels, name) in enumerate(test_loader):
+                         images = Variable(images).cuda(gpu)
+                         total += cont_labels.size(0)
+
+                         label_pitch = cont_labels[:, 0].float() * np.pi / 180
+                         label_yaw = cont_labels[:, 1].float() * np.pi / 180
+
+                         gaze_pitch, gaze_yaw = model(images)
+
+                         # Binned predictions
+                         _, pitch_bpred = torch.max(gaze_pitch.data, 1)
+                         _, yaw_bpred = torch.max(gaze_yaw.data, 1)
+
+                         # Continuous predictions
+                         pitch_predicted = softmax(gaze_pitch)
+                         yaw_predicted = softmax(gaze_yaw)
+
+                         # mapping from binned (0 to 89) to angles (-180 to 180)
+                         pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 4 - 180
+                         yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 4 - 180
+
+                         pitch_predicted = pitch_predicted * np.pi / 180
+                         yaw_predicted = yaw_predicted * np.pi / 180
+
+                         for p, y, pl, yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw):
+                             avg_error += angular(gazeto3d([p, y]), gazeto3d([pl, yl]))
+
+                 x = ''.join(filter(lambda i: i.isdigit(), epochs))
+                 epoch_list.append(x)
+                 avg_MAE.append(avg_error / total)
+                 loger = f"[{epochs}---{args.dataset}] Total Num:{total},MAE:{avg_error/total}\n"
+                 outfile.write(loger)
+                 print(loger)
+
+         fig = plt.figure(figsize=(14, 8))
+         plt.xlabel('epoch')
+         plt.ylabel('avg')
+         plt.title('Gaze angular error')
+         plt.plot(epoch_list, avg_MAE, color='k', label='mae')
+         plt.legend()
+         fig.savefig(os.path.join(evalpath, data_set + ".png"), format='png')
+         plt.show()
+
+     elif data_set == "mpiigaze":
+         for fold in range(15):
+             folder = os.listdir(args.gazeMpiilabel_dir)
+             folder.sort()
+             testlabelpathcombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]
+             gaze_dataset = Mpiigaze(testlabelpathcombined, args.gazeMpiimage_dir, transformations, False, angle, fold)
+
+             test_loader = torch.utils.data.DataLoader(
+                 dataset=gaze_dataset,
+                 batch_size=batch_size,
+                 shuffle=True,
+                 num_workers=4,
+                 pin_memory=True)
+
+             if not os.path.exists(os.path.join(evalpath, "fold" + str(fold))):
+                 os.makedirs(os.path.join(evalpath, "fold" + str(fold)))
+
+             # list all epochs for testing
+             folder = os.listdir(os.path.join(snapshot_path, "fold" + str(fold)))
+             folder.sort(key=natural_keys)
+
+             softmax = nn.Softmax(dim=1)
+             with open(os.path.join(evalpath, "fold" + str(fold), data_set + ".log"), 'w') as outfile:
+                 configuration = f"\ntest configuration = gpu_id={gpu}, batch_size={batch_size}, model_arch={arch}\nStart testing dataset={data_set}, fold={fold}---------------------------------------\n"
+                 print(configuration)
+                 outfile.write(configuration)
+                 epoch_list = []
+                 avg_MAE = []
+                 for epochs in folder:
+                     # Re-create the model per snapshot; the upstream code re-wrapped
+                     # the same instance in nn.DataParallel on every iteration.
+                     model = nn.DataParallel(getArch(arch, bins), device_ids=[0])
+                     saved_state_dict = torch.load(os.path.join(snapshot_path, "fold" + str(fold), epochs))
+                     model.load_state_dict(saved_state_dict)
+                     model.cuda(gpu)
+                     model.eval()
+                     total = 0
+                     idx_tensor = torch.FloatTensor(list(range(28))).cuda(gpu)
+                     avg_error = .0
+                     with torch.no_grad():
+                         for j, (images, labels, cont_labels, name) in enumerate(test_loader):
+                             images = Variable(images).cuda(gpu)
+                             total += cont_labels.size(0)
+
+                             label_pitch = cont_labels[:, 0].float() * np.pi / 180
+                             label_yaw = cont_labels[:, 1].float() * np.pi / 180
+
+                             gaze_pitch, gaze_yaw = model(images)
+
+                             # Binned predictions
+                             _, pitch_bpred = torch.max(gaze_pitch.data, 1)
+                             _, yaw_bpred = torch.max(gaze_yaw.data, 1)
+
+                             # Continuous predictions
+                             pitch_predicted = softmax(gaze_pitch)
+                             yaw_predicted = softmax(gaze_yaw)
+
+                             # mapping from binned (0 to 27) to angles (-42 to 42)
+                             pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 3 - 42
+                             yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 3 - 42
+
+                             pitch_predicted = pitch_predicted * np.pi / 180
+                             yaw_predicted = yaw_predicted * np.pi / 180
+
+                             for p, y, pl, yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw):
+                                 avg_error += angular(gazeto3d([p, y]), gazeto3d([pl, yl]))
+
+                     x = ''.join(filter(lambda i: i.isdigit(), epochs))
+                     epoch_list.append(x)
+                     avg_MAE.append(avg_error / total)
+                     loger = f"[{epochs}---{args.dataset}] Total Num:{total},MAE:{avg_error/total}\n"
+                     outfile.write(loger)
+                     print(loger)
+
+             fig = plt.figure(figsize=(14, 8))
+             plt.xlabel('epoch')
+             plt.ylabel('avg')
+             plt.title('Gaze angular error')
+             plt.plot(epoch_list, avg_MAE, color='k', label='mae')
+             plt.legend()
+             fig.savefig(os.path.join(evalpath, "fold" + str(fold), data_set + ".png"), format='png')
+             # plt.show()
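The continuous-angle decoding in test.py is a softmax expectation over bin indices, then an affine map from bin index to degrees. A minimal NumPy sketch of that step (the function name `decode_angle` is ours, not part of the l2cs package; parameters match the Gaze360 setting above, 90 bins of 4 degrees covering -180..180):

```python
import numpy as np

def decode_angle(logits, num_bins=90, bin_width=4.0, offset=-180.0):
    # Softmax over the bin logits (shifted for numerical stability),
    # then the expected bin index mapped to degrees: E[i] * width + offset.
    e = np.exp(logits - logits.max())
    probs = e / e.sum()
    idx = np.arange(num_bins)
    return float((probs * idx).sum() * bin_width + offset)

# A distribution sharply peaked on bin 50 decodes to about 50*4 - 180 = 20 degrees.
logits = np.full(90, -10.0)
logits[50] = 10.0
print(round(decode_angle(logits), 3))
```

This soft decoding is what lets a classification head produce continuous angles; the hard argmax (`pitch_bpred`/`yaw_bpred` above) is only computed for the binned view.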
models/L2CS-Net/train.py ADDED
@@ -0,0 +1,384 @@
+ import os
+ import argparse
+ import time
+
+ import torch.utils.model_zoo as model_zoo
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+ import torch.backends.cudnn as cudnn
+ import torchvision
+
+ from l2cs import L2CS, select_device, Gaze360, Mpiigaze
+
+
+ def parse_args():
+     """Parse input arguments."""
+     parser = argparse.ArgumentParser(description='Gaze estimation using L2CSNet.')
+     # Gaze360
+     parser.add_argument(
+         '--gaze360image_dir', dest='gaze360image_dir', help='Directory path for gaze images.',
+         default='datasets/Gaze360/Image', type=str)
+     parser.add_argument(
+         '--gaze360label_dir', dest='gaze360label_dir', help='Directory path for gaze labels.',
+         default='datasets/Gaze360/Label/train.label', type=str)
+     # MPIIGaze
+     parser.add_argument(
+         '--gazeMpiimage_dir', dest='gazeMpiimage_dir', help='Directory path for gaze images.',
+         default='datasets/MPIIFaceGaze/Image', type=str)
+     parser.add_argument(
+         '--gazeMpiilabel_dir', dest='gazeMpiilabel_dir', help='Directory path for gaze labels.',
+         default='datasets/MPIIFaceGaze/Label', type=str)
+
+     # Important args -------------------------------------------------------------------
+     parser.add_argument(
+         '--dataset', dest='dataset', help='mpiigaze, rtgene, gaze360, ethgaze',
+         default="gaze360", type=str)
+     parser.add_argument(
+         '--output', dest='output', help='Path of output models.',
+         default='output/snapshots/', type=str)
+     parser.add_argument(
+         '--snapshot', dest='snapshot', help='Path of model snapshot.',
+         default='', type=str)
+     parser.add_argument(
+         '--gpu', dest='gpu_id', help='GPU device id to use [0] or multiple 0,1,2,3',
+         default='0', type=str)
+     parser.add_argument(
+         '--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
+         default=60, type=int)
+     parser.add_argument(
+         '--batch_size', dest='batch_size', help='Batch size.',
+         default=1, type=int)
+     parser.add_argument(
+         '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, [ResNet50], '
+                                     'ResNet101, ResNet152, Squeezenet_1_0, Squeezenet_1_1, MobileNetV2',
+         default='ResNet50', type=str)
+     parser.add_argument(
+         '--alpha', dest='alpha', help='Regression loss coefficient.',
+         default=1, type=float)
+     parser.add_argument(
+         '--lr', dest='lr', help='Base learning rate.',
+         default=0.00001, type=float)
+     # Important args -------------------------------------------------------------------
+     args = parser.parse_args()
+     return args
+
+
+ def get_ignored_params(model):
+     # Generator function that yields ignored params.
+     b = [model.conv1, model.bn1, model.fc_finetune]
+     for i in range(len(b)):
+         for module_name, module in b[i].named_modules():
+             if 'bn' in module_name:
+                 module.eval()
+             for name, param in module.named_parameters():
+                 yield param
+
+
+ def get_non_ignored_params(model):
+     # Generator function that yields params that will be optimized.
+     b = [model.layer1, model.layer2, model.layer3, model.layer4]
+     for i in range(len(b)):
+         for module_name, module in b[i].named_modules():
+             if 'bn' in module_name:
+                 module.eval()
+             for name, param in module.named_parameters():
+                 yield param
+
+
+ def get_fc_params(model):
+     # Generator function that yields fc layer params.
+     b = [model.fc_yaw_gaze, model.fc_pitch_gaze]
+     for i in range(len(b)):
+         for module_name, module in b[i].named_modules():
+             for name, param in module.named_parameters():
+                 yield param
+
+
+ def load_filtered_state_dict(model, snapshot):
+     # By user apaszke from discuss.pytorch.org
+     model_dict = model.state_dict()
+     snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+     model_dict.update(snapshot)
+     model.load_state_dict(model_dict)
+
+
+ def getArch_weights(arch, bins):
+     if arch == 'ResNet18':
+         model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
+         pre_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
+     elif arch == 'ResNet34':
+         model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
+         pre_url = 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'
+     elif arch == 'ResNet101':
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
+         pre_url = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
+     elif arch == 'ResNet152':
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], bins)
+         pre_url = 'https://download.pytorch.org/models/resnet152-b121ed2d.pth'
+     else:
+         if arch != 'ResNet50':
+             print('Invalid value for architecture is passed! '
+                   'The default value of ResNet50 will be used instead!')
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
+         pre_url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
+
+     return model, pre_url
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+     cudnn.enabled = True
+     num_epochs = args.num_epochs
+     batch_size = args.batch_size
+     gpu = select_device(args.gpu_id, batch_size=args.batch_size)
+     data_set = args.dataset
+     alpha = args.alpha
+     output = args.output
+
+     transformations = transforms.Compose([
+         transforms.Resize(448),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225]
+         )
+     ])
+
+     if data_set == "gaze360":
+         model, pre_url = getArch_weights(args.arch, 90)
+         if args.snapshot == '':
+             load_filtered_state_dict(model, model_zoo.load_url(pre_url))
+         else:
+             saved_state_dict = torch.load(args.snapshot)
+             model.load_state_dict(saved_state_dict)
+
+         model.cuda(gpu)
+         dataset = Gaze360(args.gaze360label_dir, args.gaze360image_dir, transformations, 180, 4)
+         print('Loading data.')
+         train_loader_gaze = DataLoader(
+             dataset=dataset,
+             batch_size=int(batch_size),
+             shuffle=True,
+             num_workers=0,
+             pin_memory=True)
+         torch.backends.cudnn.benchmark = True
+
+         summary_name = '{}_{}'.format('L2CS-gaze360-', int(time.time()))
+         output = os.path.join(output, summary_name)
+         if not os.path.exists(output):
+             os.makedirs(output)
+
+         criterion = nn.CrossEntropyLoss().cuda(gpu)
+         reg_criterion = nn.MSELoss().cuda(gpu)
+         softmax = nn.Softmax(dim=1).cuda(gpu)
+         idx_tensor = Variable(torch.FloatTensor(list(range(90)))).cuda(gpu)
+
+         # Optimizer gaze
+         optimizer_gaze = torch.optim.Adam([
+             {'params': get_ignored_params(model), 'lr': 0},
+             {'params': get_non_ignored_params(model), 'lr': args.lr},
+             {'params': get_fc_params(model), 'lr': args.lr}
+         ], args.lr)
+
+         configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\nStart training dataset={data_set}, loader={len(train_loader_gaze)}------------------------- \n"
+         print(configuration)
+         for epoch in range(num_epochs):
+             sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0
+
+             for i, (images_gaze, labels_gaze, cont_labels_gaze, name) in enumerate(train_loader_gaze):
+                 images_gaze = Variable(images_gaze).cuda(gpu)
+
+                 # Binned labels
+                 label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
+                 label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)
+
+                 # Continuous labels
+                 label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
+                 label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)
+
+                 pitch, yaw = model(images_gaze)
+
+                 # Cross entropy loss
+                 loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
+                 loss_yaw_gaze = criterion(yaw, label_yaw_gaze)
+
+                 # MSE loss
+                 pitch_predicted = softmax(pitch)
+                 yaw_predicted = softmax(yaw)
+
+                 pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 4 - 180
+                 yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 4 - 180
+
+                 loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont_gaze)
+                 loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont_gaze)
+
+                 # Total loss
+                 loss_pitch_gaze += alpha * loss_reg_pitch
+                 loss_yaw_gaze += alpha * loss_reg_yaw
+
+                 sum_loss_pitch_gaze += loss_pitch_gaze
+                 sum_loss_yaw_gaze += loss_yaw_gaze
+
+                 loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
+                 grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
+                 optimizer_gaze.zero_grad(set_to_none=True)
+                 torch.autograd.backward(loss_seq, grad_seq)
+                 optimizer_gaze.step()
+                 # scheduler.step()
+
+                 iter_gaze += 1
+
+                 if (i + 1) % 100 == 0:
+                     print('Epoch [%d/%d], Iter [%d/%d] Losses: '
+                           'Gaze Yaw %.4f,Gaze Pitch %.4f' % (
+                               epoch + 1,
+                               num_epochs,
+                               i + 1,
+                               len(dataset) // batch_size,
+                               sum_loss_pitch_gaze / iter_gaze,
+                               sum_loss_yaw_gaze / iter_gaze
+                           ))
+
+             # Save a snapshot after every epoch.
+             if epoch % 1 == 0 and epoch < num_epochs:
+                 print('Taking snapshot...')
+                 torch.save(model.state_dict(),
+                            os.path.join(output, '_epoch_' + str(epoch + 1) + '.pkl'))
+
+     elif data_set == "mpiigaze":
+         folder = os.listdir(args.gazeMpiilabel_dir)
+         folder.sort()
+         testlabelpathcombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]
+         for fold in range(15):
+             model, pre_url = getArch_weights(args.arch, 28)
+             load_filtered_state_dict(model, model_zoo.load_url(pre_url))
+             model = nn.DataParallel(model)
+             model.to(gpu)
+             print('Loading data.')
+             dataset = Mpiigaze(testlabelpathcombined, args.gazeMpiimage_dir, transformations, True, fold)
+             train_loader_gaze = DataLoader(
+                 dataset=dataset,
+                 batch_size=int(batch_size),
+                 shuffle=True,
+                 num_workers=4,
+                 pin_memory=True)
+             torch.backends.cudnn.benchmark = True
+
+             summary_name = '{}_{}'.format('L2CS-mpiigaze', int(time.time()))
+
+             fold_output = os.path.join(output, summary_name, 'fold' + str(fold))
+             if not os.path.exists(fold_output):
+                 os.makedirs(fold_output)
+
+             criterion = nn.CrossEntropyLoss().cuda(gpu)
+             reg_criterion = nn.MSELoss().cuda(gpu)
+             softmax = nn.Softmax(dim=1).cuda(gpu)
+             idx_tensor = Variable(torch.FloatTensor(list(range(28)))).cuda(gpu)
+
+             # Optimizer gaze. The model is wrapped in DataParallel, so the
+             # parameter-group helpers are given the bare module; the upstream code
+             # passed a second args.arch argument the helpers do not accept.
+             optimizer_gaze = torch.optim.Adam([
+                 {'params': get_ignored_params(model.module), 'lr': 0},
+                 {'params': get_non_ignored_params(model.module), 'lr': args.lr},
+                 {'params': get_fc_params(model.module), 'lr': args.lr}
+             ], args.lr)
+
+             configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\nStart training dataset={data_set}, loader={len(train_loader_gaze)}, fold={fold}--------------\n"
+             print(configuration)
+             for epoch in range(num_epochs):
+                 sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0
+
+                 for i, (images_gaze, labels_gaze, cont_labels_gaze, name) in enumerate(train_loader_gaze):
+                     images_gaze = Variable(images_gaze).cuda(gpu)
+
+                     # Binned labels
+                     label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
+                     label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)
+
+                     # Continuous labels
+                     label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
+                     label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)
+
+                     pitch, yaw = model(images_gaze)
+
+                     # Cross entropy loss
+                     loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
+                     loss_yaw_gaze = criterion(yaw, label_yaw_gaze)
+
+                     # MSE loss
+                     pitch_predicted = softmax(pitch)
+                     yaw_predicted = softmax(yaw)
+
+                     pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 42
+                     yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 42
+
+                     loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont_gaze)
+                     loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont_gaze)
+
+                     # Total loss
+                     loss_pitch_gaze += alpha * loss_reg_pitch
+                     loss_yaw_gaze += alpha * loss_reg_yaw
+
+                     sum_loss_pitch_gaze += loss_pitch_gaze
+                     sum_loss_yaw_gaze += loss_yaw_gaze
+
+                     loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
+                     grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
+
+                     optimizer_gaze.zero_grad(set_to_none=True)
+                     torch.autograd.backward(loss_seq, grad_seq)
+                     optimizer_gaze.step()
+
+                     iter_gaze += 1
+
+                     if (i + 1) % 100 == 0:
+                         print('Epoch [%d/%d], Iter [%d/%d] Losses: '
+                               'Gaze Yaw %.4f,Gaze Pitch %.4f' % (
+                                   epoch + 1,
+                                   num_epochs,
+                                   i + 1,
+                                   len(dataset) // batch_size,
+                                   sum_loss_pitch_gaze / iter_gaze,
+                                   sum_loss_yaw_gaze / iter_gaze
+                               ))
+
+                 # Save models at numbered epochs, under the per-fold directory
+                 # created above (the upstream path skipped the summary folder).
+                 if epoch % 1 == 0 and epoch < num_epochs:
+                     print('Taking snapshot...')
+                     torch.save(model.state_dict(),
+                                os.path.join(fold_output, '_epoch_' + str(epoch + 1) + '.pkl'))
models/gaze_calibration.py ADDED
@@ -0,0 +1,146 @@
1
+ # 9-point gaze calibration for L2CS-Net
2
+ # Maps raw gaze angles -> normalised screen coords via polynomial least-squares.
3
+ # Centre point is the bias reference (subtracted from all readings).
4
+
5
+ import numpy as np
6
+ from dataclasses import dataclass, field
7
+
8
+ # 3x3 grid, centre first (bias ref), then row by row
9
+ DEFAULT_TARGETS = [
10
+ (0.5, 0.5),
11
+ (0.15, 0.15), (0.50, 0.15), (0.85, 0.15),
12
+ (0.15, 0.50), (0.85, 0.50),
13
+ (0.15, 0.85), (0.50, 0.85), (0.85, 0.85),
14
+ ]
15
+
16
+
17
+ @dataclass
18
+ class _PointSamples:
19
+ target_x: float
20
+ target_y: float
21
+ yaws: list = field(default_factory=list)
22
+ pitches: list = field(default_factory=list)
23
+
24
+
25
+ def _iqr_filter(values):
26
+ if len(values) < 4:
27
+ return values
28
+ arr = np.array(values)
29
+ q1, q3 = np.percentile(arr, [25, 75])
30
+ iqr = q3 - q1
31
+ lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
32
+ return arr[(arr >= lo) & (arr <= hi)].tolist()
33
+
34
+
35
+ class GazeCalibration:
36
+
37
+ def __init__(self, targets=None):
38
+ self._targets = targets or list(DEFAULT_TARGETS)
39
+ self._points = [_PointSamples(tx, ty) for tx, ty in self._targets]
40
+ self._current_idx = 0
41
+ self._fitted = False
42
+ self._W = None # (6, 2) polynomial weights
43
+ self._yaw_bias = 0.0
44
+ self._pitch_bias = 0.0
45
+
46
+ @property
47
+ def num_points(self):
48
+ return len(self._targets)
49
+
50
+ @property
51
+ def current_index(self):
52
+ return self._current_idx
53
+
54
+ @property
55
+ def current_target(self):
56
+ if self._current_idx < len(self._targets):
57
+ return self._targets[self._current_idx]
58
+ return self._targets[-1]
59
+
60
+ @property
61
+ def is_complete(self):
62
+ return self._current_idx >= len(self._targets)
63
+
64
+ @property
65
+ def is_fitted(self):
66
+ return self._fitted
67
+
68
+ def collect_sample(self, yaw_rad, pitch_rad):
69
+ if self._current_idx >= len(self._points):
70
+ return
71
+ pt = self._points[self._current_idx]
72
+ pt.yaws.append(float(yaw_rad))
73
+ pt.pitches.append(float(pitch_rad))
74
+
75
+ def advance(self):
76
+ self._current_idx += 1
77
+ return self._current_idx < len(self._targets)
78
+
79
+ @staticmethod
80
+ def _poly_features(yaw, pitch):
81
+ # [yaw^2, pitch^2, yaw*pitch, yaw, pitch, 1]
82
+        return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0],
+                        dtype=np.float64)
+
+    def fit(self):
+        # bias from centre point (index 0)
+        center = self._points[0]
+        center_yaws = _iqr_filter(center.yaws)
+        center_pitches = _iqr_filter(center.pitches)
+        if len(center_yaws) < 2 or len(center_pitches) < 2:
+            return False
+        self._yaw_bias = float(np.median(center_yaws))
+        self._pitch_bias = float(np.median(center_pitches))
+
+        rows_A, rows_B = [], []
+        for pt in self._points:
+            clean_yaws = _iqr_filter(pt.yaws)
+            clean_pitches = _iqr_filter(pt.pitches)
+            if len(clean_yaws) < 2 or len(clean_pitches) < 2:
+                continue
+            med_yaw = float(np.median(clean_yaws)) - self._yaw_bias
+            med_pitch = float(np.median(clean_pitches)) - self._pitch_bias
+            rows_A.append(self._poly_features(med_yaw, med_pitch))
+            rows_B.append([pt.target_x, pt.target_y])
+
+        if len(rows_A) < 5:
+            return False
+
+        A = np.array(rows_A, dtype=np.float64)
+        B = np.array(rows_B, dtype=np.float64)
+        try:
+            W, _, _, _ = np.linalg.lstsq(A, B, rcond=None)
+            self._W = W
+            self._fitted = True
+            return True
+        except np.linalg.LinAlgError:
+            return False
+
+    def predict(self, yaw_rad, pitch_rad):
+        if not self._fitted or self._W is None:
+            return 0.5, 0.5
+        feat = self._poly_features(yaw_rad - self._yaw_bias, pitch_rad - self._pitch_bias)
+        xy = feat @ self._W
+        return float(np.clip(xy[0], 0, 1)), float(np.clip(xy[1], 0, 1))
+
+    def to_dict(self):
+        return {
+            "targets": self._targets,
+            "fitted": self._fitted,
+            "current_index": self._current_idx,
+            "W": self._W.tolist() if self._W is not None else None,
+            "yaw_bias": self._yaw_bias,
+            "pitch_bias": self._pitch_bias,
+        }
+
+    @classmethod
+    def from_dict(cls, d):
+        cal = cls(targets=d.get("targets", DEFAULT_TARGETS))
+        cal._fitted = d.get("fitted", False)
+        cal._current_idx = d.get("current_index", 0)
+        cal._yaw_bias = d.get("yaw_bias", 0.0)
+        cal._pitch_bias = d.get("pitch_bias", 0.0)
+        w = d.get("W")
+        if w is not None:
+            cal._W = np.array(w, dtype=np.float64)
+        return cal
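The `fit` above leans on an `_iqr_filter` helper, defined earlier in the file, to discard jittery gaze samples before the medians are taken. As a rough illustration of the idea, here is a standalone pure-Python interquartile-range filter — the simple quartile indexing and the `k=1.5` fence are assumptions for this sketch, not the in-tree implementation:

```python
def iqr_filter(samples, k=1.5):
    """Keep samples inside [Q1 - k*IQR, Q3 + k*IQR]; pass tiny inputs through."""
    xs = sorted(samples)
    n = len(xs)
    if n < 4:
        return list(samples)
    q1, q3 = xs[n // 4], xs[(3 * n) // 4]  # crude quartiles, fine for ~dozens of samples
    lo, hi = q1 - k * (q3 - q1), q3 + k * (q3 - q1)
    return [x for x in samples if lo <= x <= hi]
```

With a couple of dozen yaw samples per calibration dot, a single glance away from the target becomes an outlier that this fence removes, so the median fed into the least-squares fit stays stable.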
models/gaze_eye_fusion.py ADDED
@@ -0,0 +1,66 @@
+# Fuses calibrated gaze position with eye openness (EAR) for focus detection.
+# Takes L2CS gaze angles + MediaPipe landmarks, outputs screen coords + focus decision.
+
+import math
+import numpy as np
+
+from .gaze_calibration import GazeCalibration
+from .eye_scorer import compute_avg_ear
+
+_EAR_BLINK = 0.18
+_ON_SCREEN_MARGIN = 0.08
+
+
+class GazeEyeFusion:
+
+    def __init__(self, calibration, ear_weight=0.3, gaze_weight=0.7, focus_threshold=0.52):
+        if not calibration.is_fitted:
+            raise ValueError("Calibration must be fitted first")
+        self._cal = calibration
+        self._ear_w = ear_weight
+        self._gaze_w = gaze_weight
+        self._threshold = focus_threshold
+        self._smooth_x = 0.5
+        self._smooth_y = 0.5
+        self._alpha = 0.5
+
+    def update(self, yaw_rad, pitch_rad, landmarks):
+        gx, gy = self._cal.predict(yaw_rad, pitch_rad)
+
+        # EMA smooth the gaze position
+        self._smooth_x += self._alpha * (gx - self._smooth_x)
+        self._smooth_y += self._alpha * (gy - self._smooth_y)
+        gx, gy = self._smooth_x, self._smooth_y
+
+        on_screen = (
+            -_ON_SCREEN_MARGIN <= gx <= 1.0 + _ON_SCREEN_MARGIN and
+            -_ON_SCREEN_MARGIN <= gy <= 1.0 + _ON_SCREEN_MARGIN
+        )
+
+        ear = None
+        ear_score = 1.0
+        if landmarks is not None:
+            ear = compute_avg_ear(landmarks)
+            ear_score = 0.0 if ear < _EAR_BLINK else min(ear / 0.30, 1.0)
+
+        # penalise gaze near screen edges
+        gaze_score = 1.0 if on_screen else 0.0
+        if on_screen:
+            dx = max(0.0, abs(gx - 0.5) - 0.3)
+            dy = max(0.0, abs(gy - 0.5) - 0.3)
+            gaze_score = max(0.0, 1.0 - math.sqrt(dx**2 + dy**2) * 5.0)
+
+        score = float(np.clip(self._gaze_w * gaze_score + self._ear_w * ear_score, 0, 1))
+
+        return {
+            "gaze_x": round(float(gx), 4),
+            "gaze_y": round(float(gy), 4),
+            "on_screen": on_screen,
+            "ear": round(ear, 4) if ear is not None else None,
+            "focus_score": round(score, 4),
+            "focused": score >= self._threshold,
+        }
+
+    def reset(self):
+        self._smooth_x = 0.5
+        self._smooth_y = 0.5
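The weighted blend in `update` can be exercised in isolation. A minimal sketch of the same scoring rule, with the constants copied from the defaults above (`fusion_score` is a hypothetical free function, not part of the module):

```python
import math

def fusion_score(gx, gy, ear, gaze_w=0.7, ear_w=0.3, ear_blink=0.18):
    # gaze term: 1.0 inside the central band of the screen, falling off toward edges
    dx = max(0.0, abs(gx - 0.5) - 0.3)
    dy = max(0.0, abs(gy - 0.5) - 0.3)
    gaze_score = max(0.0, 1.0 - math.sqrt(dx * dx + dy * dy) * 5.0)
    # eye term: closed eyes score 0, fully open (EAR >= 0.30) scores 1
    ear_score = 0.0 if ear < ear_blink else min(ear / 0.30, 1.0)
    return min(1.0, max(0.0, gaze_w * gaze_score + ear_w * ear_score))
```

Looking dead centre with open eyes gives 1.0; a blink zeroes the eye term but leaves 0.7 from the gaze term, which is still above the 0.52 focus threshold — so a brief blink alone never flips the focus flag.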
requirements.txt CHANGED
@@ -14,3 +14,7 @@ aiosqlite>=0.19.0
 pydantic>=2.0.0
 xgboost>=2.0.0
 clearml>=2.0.2
+torch>=1.10.1
+torchvision>=0.11.2
+face_detection @ git+https://github.com/elliottzheng/face-detection
+gdown>=5.0.0
src/components/CalibrationOverlay.jsx ADDED
@@ -0,0 +1,146 @@
+import React, { useState, useEffect, useRef, useCallback } from 'react';
+
+const COLLECT_MS = 2000;
+const CENTER_MS = 3000; // centre point gets extra time (bias reference)
+
+function CalibrationOverlay({ calibration, videoManager }) {
+  const [progress, setProgress] = useState(0);
+  const timerRef = useRef(null);
+  const startRef = useRef(null);
+  const overlayRef = useRef(null);
+
+  const enterFullscreen = useCallback(() => {
+    const el = overlayRef.current;
+    if (!el) return;
+    const req = el.requestFullscreen || el.webkitRequestFullscreen || el.msRequestFullscreen;
+    if (req) {
+      // vendor-prefixed variants may not return a promise
+      const p = req.call(el);
+      if (p && p.catch) p.catch(() => {});
+    }
+  }, []);
+
+  const exitFullscreen = useCallback(() => {
+    if (document.fullscreenElement || document.webkitFullscreenElement) {
+      const exit = document.exitFullscreen || document.webkitExitFullscreen || document.msExitFullscreen;
+      if (exit) {
+        const p = exit.call(document);
+        if (p && p.catch) p.catch(() => {});
+      }
+    }
+  }, []);
+
+  useEffect(() => {
+    if (calibration && calibration.active && !calibration.done) {
+      const t = setTimeout(enterFullscreen, 100);
+      return () => clearTimeout(t);
+    }
+  }, [calibration?.active]);
+
+  useEffect(() => {
+    if (!calibration || !calibration.active) exitFullscreen();
+  }, [calibration?.active]);
+
+  useEffect(() => {
+    if (!calibration || !calibration.collecting || calibration.done) {
+      setProgress(0);
+      if (timerRef.current) cancelAnimationFrame(timerRef.current);
+      return;
+    }
+
+    startRef.current = performance.now();
+    const duration = calibration.index === 0 ? CENTER_MS : COLLECT_MS;
+
+    const tick = () => {
+      const pct = Math.min((performance.now() - startRef.current) / duration, 1);
+      setProgress(pct);
+      if (pct >= 1) {
+        if (videoManager) videoManager.nextCalibrationPoint();
+        startRef.current = performance.now();
+        setProgress(0);
+      }
+      timerRef.current = requestAnimationFrame(tick);
+    };
+    timerRef.current = requestAnimationFrame(tick);
+
+    return () => { if (timerRef.current) cancelAnimationFrame(timerRef.current); };
+  }, [calibration?.index, calibration?.collecting, calibration?.done]);
+
+  const handleCancel = () => {
+    if (videoManager) videoManager.cancelCalibration();
+    exitFullscreen();
+  };
+
+  if (!calibration || !calibration.active) return null;
+
+  if (calibration.done) {
+    return (
+      <div ref={overlayRef} style={overlayStyle}>
+        <div style={messageBoxStyle}>
+          <h2 style={{ margin: '0 0 10px', color: calibration.success ? '#4ade80' : '#f87171' }}>
+            {calibration.success ? 'Calibration Complete' : 'Calibration Failed'}
+          </h2>
+          <p style={{ color: '#ccc', margin: 0 }}>
+            {calibration.success
+              ? 'Gaze tracking is now active.'
+              : 'Not enough samples collected. Try again.'}
+          </p>
+        </div>
+      </div>
+    );
+  }
+
+  const [tx, ty] = calibration.target || [0.5, 0.5];
+
+  return (
+    <div ref={overlayRef} style={overlayStyle}>
+      <div style={{
+        position: 'absolute', top: '30px', left: '50%', transform: 'translateX(-50%)',
+        color: '#fff', fontSize: '16px', textAlign: 'center',
+        textShadow: '0 0 8px rgba(0,0,0,0.8)', pointerEvents: 'none',
+      }}>
+        <div style={{ fontWeight: 'bold', fontSize: '20px' }}>
+          Look at the dot ({calibration.index + 1}/{calibration.numPoints})
+        </div>
+        <div style={{ fontSize: '14px', color: '#aaa', marginTop: '6px' }}>
+          {calibration.index === 0
+            ? 'Look at the center dot - this sets your baseline'
+            : 'Hold your gaze steady on the target'}
+        </div>
+      </div>
+
+      <div style={{
+        position: 'absolute', left: `${tx * 100}%`, top: `${ty * 100}%`,
+        transform: 'translate(-50%, -50%)',
+      }}>
+        <svg width="60" height="60" style={{ position: 'absolute', left: '-30px', top: '-30px' }}>
+          <circle cx="30" cy="30" r="24" fill="none" stroke="rgba(255,255,255,0.15)" strokeWidth="3" />
+          <circle cx="30" cy="30" r="24" fill="none" stroke="#4ade80" strokeWidth="3"
+            strokeDasharray={`${progress * 150.8} 150.8`} strokeLinecap="round"
+            transform="rotate(-90, 30, 30)" />
+        </svg>
+        <div style={{
+          width: '20px', height: '20px', borderRadius: '50%',
+          background: 'radial-gradient(circle, #fff 30%, #4ade80 100%)',
+          boxShadow: '0 0 20px rgba(74, 222, 128, 0.8)',
+        }} />
+      </div>
+
+      <button onClick={handleCancel} style={{
+        position: 'absolute', bottom: '40px', left: '50%', transform: 'translateX(-50%)',
+        padding: '10px 28px', background: 'rgba(255,255,255,0.1)',
+        border: '1px solid rgba(255,255,255,0.3)', color: '#fff',
+        borderRadius: '20px', cursor: 'pointer', fontSize: '14px',
+      }}>
+        Cancel Calibration
+      </button>
+    </div>
+  );
+}
+
+const overlayStyle = {
+  position: 'fixed', top: 0, left: 0, width: '100vw', height: '100vh',
+  background: 'rgba(0, 0, 0, 0.92)', zIndex: 10000,
+  display: 'flex', alignItems: 'center', justifyContent: 'center',
+};
+
+const messageBoxStyle = {
+  textAlign: 'center', padding: '30px 40px',
+  background: 'rgba(30, 30, 50, 0.9)', borderRadius: '16px',
+  border: '1px solid rgba(255,255,255,0.1)',
+};
+
+export default CalibrationOverlay;
src/components/FocusPageLocal.jsx CHANGED
@@ -1,4 +1,5 @@
 import React, { useState, useEffect, useRef } from 'react';
+import CalibrationOverlay from './CalibrationOverlay';
 
 function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActive }) {
   const [currentFrame, setCurrentFrame] = useState(15);
@@ -6,6 +7,9 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
   const [stats, setStats] = useState(null);
   const [availableModels, setAvailableModels] = useState([]);
   const [currentModel, setCurrentModel] = useState('mlp');
+  const [calibration, setCalibration] = useState(null);
+  const [l2csBoost, setL2csBoost] = useState(false);
+  const [l2csBoostAvailable, setL2csBoostAvailable] = useState(false);
 
   const localVideoRef = useRef(null);
   const displayCanvasRef = useRef(null);
@@ -23,7 +27,6 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
   useEffect(() => {
     if (!videoManager) return;
 
-    // 设置回调函数来更新时间轴
     const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
     videoManager.callbacks.onStatusUpdate = (isFocused) => {
       setTimelineEvents(prev => {
@@ -34,7 +37,10 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
       if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
     };
 
-    // 定期更新统计信息
+    videoManager.callbacks.onCalibrationUpdate = (cal) => {
+      setCalibration(cal && cal.active ? { ...cal } : null);
+    };
+
     const statsInterval = setInterval(() => {
       if (videoManager && videoManager.getStats) {
         setStats(videoManager.getStats());
@@ -44,6 +50,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
     return () => {
       if (videoManager) {
        videoManager.callbacks.onStatusUpdate = originalOnStatusUpdate;
+        videoManager.callbacks.onCalibrationUpdate = null;
      }
      clearInterval(statsInterval);
    };
@@ -56,6 +63,8 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
      .then(data => {
        if (data.available) setAvailableModels(data.available);
        if (data.current) setCurrentModel(data.current);
+        if (data.l2cs_boost !== undefined) setL2csBoost(data.l2cs_boost);
+        if (data.l2cs_boost_available !== undefined) setL2csBoostAvailable(data.l2cs_boost_available);
      })
      .catch(err => console.error('Failed to fetch models:', err));
  }, []);
@@ -70,12 +79,28 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
      const result = await res.json();
      if (result.updated) {
        setCurrentModel(modelName);
+        setL2csBoostAvailable(modelName !== 'l2cs' && availableModels.includes('l2cs'));
+        if (modelName === 'l2cs') setL2csBoost(false);
      }
    } catch (err) {
      console.error('Failed to switch model:', err);
    }
  };
 
+  const handleBoostToggle = async () => {
+    const next = !l2csBoost;
+    try {
+      const res = await fetch('/api/settings', {
+        method: 'PUT',
+        headers: { 'Content-Type': 'application/json' },
+        body: JSON.stringify({ l2cs_boost: next })
+      });
+      if (res.ok) setL2csBoost(next);
+    } catch (err) {
+      console.error('Failed to toggle L2CS boost:', err);
+    }
+  };
+
  const handleStart = async () => {
    try {
      if (videoManager) {
@@ -443,6 +468,44 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
              {name}
            </button>
          ))}
+          {l2csBoostAvailable && currentModel !== 'l2cs' && (
+            <button
+              onClick={handleBoostToggle}
+              style={{
+                padding: '5px 14px',
+                borderRadius: '16px',
+                border: l2csBoost ? '2px solid #f59e0b' : '1px solid #555',
+                background: l2csBoost ? 'rgba(245, 158, 11, 0.15)' : 'transparent',
+                color: l2csBoost ? '#f59e0b' : '#888',
+                fontSize: '11px',
+                fontWeight: l2csBoost ? 'bold' : 'normal',
+                cursor: 'pointer',
+                transition: 'all 0.2s',
+                marginLeft: '4px',
+              }}
+            >
+              {l2csBoost ? 'GAZE ON' : 'GAZE'}
+            </button>
+          )}
+          {(currentModel === 'l2cs' || l2csBoost) && stats && stats.isStreaming && (
+            <button
+              onClick={() => videoManager && videoManager.startCalibration()}
+              style={{
+                padding: '5px 14px',
+                borderRadius: '16px',
+                border: '1px solid #4ade80',
+                background: 'transparent',
+                color: '#4ade80',
+                fontSize: '12px',
+                fontWeight: 'bold',
+                cursor: 'pointer',
+                transition: 'all 0.2s',
+                marginLeft: '4px',
+              }}
+            >
+              Calibrate
+            </button>
+          )}
        </section>
      )}
 
@@ -513,6 +576,9 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
            onChange={(e) => handleFrameChange(e.target.value)}
          />
        </section>
+
+      {/* Calibration overlay (fixed fullscreen, must be outside overflow:hidden containers) */}
+      <CalibrationOverlay calibration={calibration} videoManager={videoManager} />
    </main>
  );
 }
src/utils/VideoManagerLocal.js CHANGED
@@ -39,6 +39,17 @@ export class VideoManagerLocal {
     this.lastNotificationTime = null;
     this.notificationCooldown = 60000;
 
+    // Calibration state
+    this.calibration = {
+      active: false,
+      collecting: false,
+      target: null,
+      index: 0,
+      numPoints: 0,
+      done: false,
+      success: false,
+    };
+
     // 性能统计
     this.stats = {
       framesSent: 0,
@@ -73,8 +84,8 @@ export class VideoManagerLocal {
 
     // 创建用于截图的 canvas (smaller for faster encode + transfer)
     this.canvas = document.createElement('canvas');
-    this.canvas.width = 320;
-    this.canvas.height = 240;
+    this.canvas.width = 640;
+    this.canvas.height = 480;
 
     console.log('Local camera initialized');
     return true;
@@ -188,7 +199,7 @@ export class VideoManagerLocal {
           this.ws.send(blob);
           this.stats.framesSent++;
         }
-      }, 'image/jpeg', 0.5);
+      }, 'image/jpeg', 0.75);
     } catch (error) {
       this._sendingBlob = false;
       console.error('Capture error:', error);
@@ -253,6 +264,19 @@ export class VideoManagerLocal {
        ctx.textAlign = 'left';
      }
    }
+    // Gaze pointer (L2CS + calibration)
+    if (data && data.gaze_x !== undefined && data.gaze_y !== undefined) {
+      const gx = data.gaze_x * w;
+      const gy = data.gaze_y * h;
+      ctx.beginPath();
+      ctx.arc(gx, gy, 8, 0, 2 * Math.PI);
+      ctx.fillStyle = data.on_screen ? 'rgba(0, 200, 255, 0.7)' : 'rgba(255, 80, 80, 0.5)';
+      ctx.fill();
+      ctx.strokeStyle = '#FFFFFF';
+      ctx.lineWidth = 2;
+      ctx.stroke();
+    }
+
    // Performance stats
    ctx.fillStyle = 'rgba(0,0,0,0.5)';
    ctx.fillRect(0, h - 25, w, 25);
@@ -321,6 +345,9 @@ export class VideoManagerLocal {
          mar: data.mar,
          sf: data.sf,
          se: data.se,
+          gaze_x: data.gaze_x,
+          gaze_y: data.gaze_y,
+          on_screen: data.on_screen,
        };
        this.drawDetectionResult(detectionData);
        break;
@@ -338,6 +365,51 @@ export class VideoManagerLocal {
        this.sessionStartTime = null;
        break;
 
+      case 'calibration_started':
+        this.calibration = {
+          active: true,
+          collecting: true,
+          target: data.target,
+          index: data.index,
+          numPoints: data.num_points,
+          done: false,
+          success: false,
+        };
+        if (this.callbacks.onCalibrationUpdate) {
+          this.callbacks.onCalibrationUpdate({ ...this.calibration });
+        }
+        break;
+
+      case 'calibration_point':
+        this.calibration.target = data.target;
+        this.calibration.index = data.index;
+        if (this.callbacks.onCalibrationUpdate) {
+          this.callbacks.onCalibrationUpdate({ ...this.calibration });
+        }
+        break;
+
+      case 'calibration_done':
+        this.calibration.collecting = false;
+        this.calibration.done = true;
+        this.calibration.success = data.success;
+        if (this.callbacks.onCalibrationUpdate) {
+          this.callbacks.onCalibrationUpdate({ ...this.calibration });
+        }
+        setTimeout(() => {
+          this.calibration.active = false;
+          if (this.callbacks.onCalibrationUpdate) {
+            this.callbacks.onCalibrationUpdate({ ...this.calibration });
+          }
+        }, 2000);
+        break;
+
+      case 'calibration_cancelled':
+        this.calibration = { active: false, collecting: false, target: null, index: 0, numPoints: 0, done: false, success: false };
+        if (this.callbacks.onCalibrationUpdate) {
+          this.callbacks.onCalibrationUpdate({ ...this.calibration });
+        }
+        break;
+
      case 'error':
        console.error('Server error:', data.message);
        break;
@@ -347,6 +419,28 @@ export class VideoManagerLocal {
    }
  }
 
+  startCalibration() {
+    if (this.ws && this.ws.readyState === WebSocket.OPEN) {
+      this.ws.send(JSON.stringify({ type: 'calibration_start' }));
+    }
+  }
+
+  nextCalibrationPoint() {
+    if (this.ws && this.ws.readyState === WebSocket.OPEN) {
+      this.ws.send(JSON.stringify({ type: 'calibration_next' }));
+    }
+  }
+
+  cancelCalibration() {
+    if (this.ws && this.ws.readyState === WebSocket.OPEN) {
+      this.ws.send(JSON.stringify({ type: 'calibration_cancel' }));
+    }
+    this.calibration = { active: false, collecting: false, target: null, index: 0, numPoints: 0, done: false, success: false };
+    if (this.callbacks.onCalibrationUpdate) {
+      this.callbacks.onCalibrationUpdate({ ...this.calibration });
+    }
+  }
+
  // Face mesh landmark index groups (matches live_demo.py)
  static FACE_OVAL = [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109,10];
  static LEFT_EYE = [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246];
ui/pipeline.py CHANGED
@@ -3,6 +3,7 @@ import glob
 import json
 import math
 import os
+import pathlib
 import sys
 
 import numpy as np
@@ -49,10 +50,12 @@ def _clip_features(vec):
 
 
 class _OutputSmoother:
-    """EMA smoothing on focus score with no-face grace period."""
+    # Asymmetric EMA: rises fast (recognise focus), falls slower (avoid flicker).
+    # Grace period holds score steady for a few frames when face is lost.
 
-    def __init__(self, alpha: float = 0.3, grace_frames: int = 15):
-        self._alpha = alpha
+    def __init__(self, alpha_up=0.55, alpha_down=0.45, grace_frames=10):
+        self._alpha_up = alpha_up
+        self._alpha_down = alpha_down
         self._grace = grace_frames
         self._score = 0.5
         self._no_face = 0
@@ -61,14 +64,15 @@ class _OutputSmoother:
         self._score = 0.5
         self._no_face = 0
 
-    def update(self, raw_score: float, face_detected: bool) -> float:
+    def update(self, raw_score, face_detected):
         if face_detected:
             self._no_face = 0
-            self._score += self._alpha * (raw_score - self._score)
+            alpha = self._alpha_up if raw_score > self._score else self._alpha_down
+            self._score += alpha * (raw_score - self._score)
         else:
             self._no_face += 1
             if self._no_face > self._grace:
-                self._score *= 0.85
+                self._score *= 0.80
         return self._score
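The asymmetric update rule above is easiest to see with a standalone trace. A sketch assuming the new defaults (face always detected, so the grace-period branch never fires):

```python
def asymmetric_ema(raw_scores, alpha_up=0.55, alpha_down=0.45, start=0.5):
    """Replay _OutputSmoother.update for a detected face: fast attack, slower decay."""
    s, trace = start, []
    for raw in raw_scores:
        alpha = alpha_up if raw > s else alpha_down
        s += alpha * (raw - s)
        trace.append(s)
    return trace
```

A jump from 0.5 toward 1.0 closes 55% of the gap in one frame, while a drop toward 0.0 closes only 45%, so momentary score dips are damped more than genuine refocusing.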
@@ -640,3 +644,141 @@
 
     def __exit__(self, *args):
         self.close()
+
+
+def _resolve_l2cs_weights():
+    for p in [
+        os.path.join(_PROJECT_ROOT, "models", "L2CS-Net", "models", "L2CSNet_gaze360.pkl"),
+        os.path.join(_PROJECT_ROOT, "models", "L2CSNet_gaze360.pkl"),
+        os.path.join(_PROJECT_ROOT, "checkpoints", "L2CSNet_gaze360.pkl"),
+    ]:
+        if os.path.isfile(p):
+            return p
+    return None
+
+
+def is_l2cs_weights_available():
+    return _resolve_l2cs_weights() is not None
+
+
+class L2CSPipeline:
+    # Uses in-tree l2cs.Pipeline (RetinaFace + ResNet50) for gaze estimation
+    # and MediaPipe for head pose, EAR, MAR, and roll de-rotation.
+
+    YAW_THRESHOLD = 22.0
+    PITCH_THRESHOLD = 20.0
+
+    def __init__(self, weights_path=None, arch="ResNet50", device="cpu",
+                 threshold=0.52, detector=None):
+        resolved = weights_path or _resolve_l2cs_weights()
+        if resolved is None or not os.path.isfile(resolved):
+            raise FileNotFoundError(
+                "L2CS weights not found. Place L2CSNet_gaze360.pkl in "
+                "models/L2CS-Net/models/ or checkpoints/"
+            )
+
+        # add in-tree L2CS-Net to import path
+        l2cs_root = os.path.join(_PROJECT_ROOT, "models", "L2CS-Net")
+        if l2cs_root not in sys.path:
+            sys.path.insert(0, l2cs_root)
+        from l2cs import Pipeline as _L2CSPipeline
+
+        import torch
+        # bypass upstream select_device bug by constructing torch.device directly
+        self._pipeline = _L2CSPipeline(
+            weights=pathlib.Path(resolved), arch=arch, device=torch.device(device),
+        )
+
+        self._detector = detector or FaceMeshDetector()
+        self._owns_detector = detector is None
+        self._head_pose = HeadPoseEstimator()
+        self.head_pose = self._head_pose
+        self._eye_scorer = EyeBehaviourScorer()
+        self._threshold = threshold
+        self._smoother = _OutputSmoother()
+
+        print(
+            f"[L2CS] Loaded {resolved} | arch={arch} device={device} "
+            f"yaw_thresh={self.YAW_THRESHOLD} pitch_thresh={self.PITCH_THRESHOLD} "
+            f"threshold={threshold}"
+        )
+
+    @staticmethod
+    def _derotate_gaze(pitch_rad, yaw_rad, roll_deg):
+        # remove head roll so tilted-but-looking-at-screen reads as (0, 0)
+        roll_rad = -math.radians(roll_deg)
+        cos_r, sin_r = math.cos(roll_rad), math.sin(roll_rad)
+        return (yaw_rad * sin_r + pitch_rad * cos_r,
+                yaw_rad * cos_r - pitch_rad * sin_r)
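`_derotate_gaze` is a plain 2D rotation of the (yaw, pitch) vector by the negated head roll. A self-contained sanity check of that identity (the same formula copied into a free function for testing):

```python
import math

def derotate_gaze(pitch_rad, yaw_rad, roll_deg):
    # rotate the gaze vector back by the head's roll angle
    roll_rad = -math.radians(roll_deg)
    cos_r, sin_r = math.cos(roll_rad), math.sin(roll_rad)
    return (yaw_rad * sin_r + pitch_rad * cos_r,
            yaw_rad * cos_r - pitch_rad * sin_r)
```

With zero roll the angles pass through unchanged, and being a rotation it preserves the vector's magnitude — so a head tilted while looking at the screen centre still de-rotates to a near-zero gaze offset.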
+
+    def process_frame(self, bgr_frame):
+        landmarks = self._detector.process(bgr_frame)
+        h, w = bgr_frame.shape[:2]
+
+        out = {
+            "landmarks": landmarks, "is_focused": False, "raw_score": 0.0,
+            "s_face": 0.0, "s_eye": 0.0, "gaze_pitch": None, "gaze_yaw": None,
+            "yaw": None, "pitch": None, "roll": None, "mar": None, "is_yawning": False,
+        }
+
+        # MediaPipe: head pose, eye/mouth scores
+        roll_deg = 0.0
+        if landmarks is not None:
+            angles = self._head_pose.estimate(landmarks, w, h)
+            if angles is not None:
+                out["yaw"], out["pitch"], out["roll"] = angles
+                roll_deg = angles[2]
+            out["s_face"] = self._head_pose.score(landmarks, w, h)
+            out["s_eye"] = self._eye_scorer.score(landmarks)
+            out["mar"] = compute_mar(landmarks)
+            out["is_yawning"] = out["mar"] > MAR_YAWN_THRESHOLD
+
+        # L2CS gaze (uses its own RetinaFace detector internally)
+        results = self._pipeline.step(bgr_frame)
+
+        if results is None or results.pitch.shape[0] == 0:
+            smoothed = self._smoother.update(0.0, landmarks is not None)
+            out["raw_score"] = smoothed
+            out["is_focused"] = smoothed >= self._threshold
+            return out
+
+        pitch_rad = float(results.pitch[0])
+        yaw_rad = float(results.yaw[0])
+
+        pitch_rad, yaw_rad = self._derotate_gaze(pitch_rad, yaw_rad, roll_deg)
+        out["gaze_pitch"] = pitch_rad
+        out["gaze_yaw"] = yaw_rad
+
+        yaw_deg = abs(math.degrees(yaw_rad))
+        pitch_deg = abs(math.degrees(pitch_rad))
+
+        # fall back to L2CS angles if MediaPipe didn't produce head pose
+        out["yaw"] = out.get("yaw") or math.degrees(yaw_rad)
+        out["pitch"] = out.get("pitch") or math.degrees(pitch_rad)
+
+        # cosine scoring: 1.0 at centre, 0.0 at threshold
+        yaw_t = min(yaw_deg / self.YAW_THRESHOLD, 1.0)
+        pitch_t = min(pitch_deg / self.PITCH_THRESHOLD, 1.0)
+        yaw_score = 0.5 * (1.0 + math.cos(math.pi * yaw_t))
+        pitch_score = 0.5 * (1.0 + math.cos(math.pi * pitch_t))
+        gaze_score = 0.55 * yaw_score + 0.45 * pitch_score
+
+        if out["is_yawning"]:
+            gaze_score = 0.0
+
+        out["raw_score"] = self._smoother.update(float(gaze_score), True)
+        out["is_focused"] = out["raw_score"] >= self._threshold
+        return out
+
+    def reset_session(self):
+        self._smoother.reset()
+
+    def close(self):
+        if self._owns_detector:
+            self._detector.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
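The raised-cosine falloff behind `raw_score` can be checked on its own. A sketch with the class constants inlined (`cosine_gaze_score` is a hypothetical helper, not part of the pipeline):

```python
import math

def cosine_gaze_score(yaw_deg, pitch_deg, yaw_thr=22.0, pitch_thr=20.0):
    # 1.0 looking dead centre, smoothly down to 0.0 at the angular threshold
    yaw_t = min(abs(yaw_deg) / yaw_thr, 1.0)
    pitch_t = min(abs(pitch_deg) / pitch_thr, 1.0)
    yaw_score = 0.5 * (1.0 + math.cos(math.pi * yaw_t))
    pitch_score = 0.5 * (1.0 + math.cos(math.pi * pitch_t))
    return 0.55 * yaw_score + 0.45 * pitch_score
```

Unlike a hard angular cutoff, the cosine gives a soft shoulder: at half the threshold the score is still 0.5, so small head movements degrade the score gradually instead of toggling focus on and off.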