Abdelrahman Almatrooshi committed
Commit · 2eba0cc
Parent(s): fad97ce
Integrate L2CS-Net gaze estimation
- Add L2CS-Net in-tree (models/L2CS-Net/) with Gaze360 weights via Git LFS
- L2CSPipeline: ResNet50 gaze + MediaPipe head pose, roll de-rotation, cosine scoring
- 9-point polynomial gaze calibration with bias correction and IQR outlier filtering
- Gaze-eye fusion: calibrated screen coords + EAR for focus detection
- L2CS Boost mode: runs gaze alongside any base model (35/65 weight, veto at 0.38)
- Calibration UI: fullscreen overlay, auto-advance, progress ring
- Frontend: GAZE toggle, Calibrate button, gaze pointer dot on canvas
- Bumped capture resolution to 640x480 @ JPEG 0.75
- Dockerfile: added git, CPU-only torch for HF Space deployment
- Dockerfile +8 -1
- README.md +87 -3
- download_l2cs_weights.py +37 -0
- main.py +273 -26
- models/L2CS-Net/.gitignore +140 -0
- models/L2CS-Net/LICENSE +21 -0
- models/L2CS-Net/README.md +148 -0
- models/L2CS-Net/demo.py +87 -0
- models/L2CS-Net/l2cs/__init__.py +21 -0
- models/L2CS-Net/l2cs/datasets.py +157 -0
- models/L2CS-Net/l2cs/model.py +73 -0
- models/L2CS-Net/l2cs/pipeline.py +133 -0
- models/L2CS-Net/l2cs/results.py +11 -0
- models/L2CS-Net/l2cs/utils.py +145 -0
- models/L2CS-Net/l2cs/vis.py +64 -0
- models/L2CS-Net/leave_one_out_eval.py +54 -0
- models/L2CS-Net/models/L2CSNet_gaze360.pkl +3 -0
- models/L2CS-Net/models/README.md +1 -0
- models/L2CS-Net/pyproject.toml +44 -0
- models/L2CS-Net/test.py +284 -0
- models/L2CS-Net/train.py +384 -0
- models/gaze_calibration.py +146 -0
- models/gaze_eye_fusion.py +66 -0
- requirements.txt +4 -0
- src/components/CalibrationOverlay.jsx +146 -0
- src/components/FocusPageLocal.jsx +68 -2
- src/utils/VideoManagerLocal.js +97 -3
- ui/pipeline.py +148 -6
Dockerfile CHANGED

```diff
@@ -7,7 +7,14 @@ ENV PYTHONUNBUFFERED=1
 
 WORKDIR /app
 
-RUN apt-get update && apt-get install -y --no-install-recommends
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 \
+    ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev \
+    libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev \
+    build-essential nodejs npm git \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
 COPY requirements.txt ./
 RUN pip install --no-cache-dir -r requirements.txt
```
README.md CHANGED

```diff
@@ -1,10 +1,94 @@
 ---
-title:
-emoji: 📚
+title: FocusGuard
 colorFrom: indigo
 colorTo: purple
 sdk: docker
 pinned: false
 ---
 
-
```
# FocusGuard - Real-Time Focus Detection

A web app that monitors whether you're focused on your screen using your webcam. Combines head pose estimation, eye behaviour analysis, and deep learning gaze tracking to detect attention in real time.

## How It Works

1. **Open the app** and click **Start** - your webcam feed appears with a face mesh overlay.
2. **Pick a model** from the selector bar (Geometric, XGBoost, L2CS, etc.).
3. The system analyses each frame and shows **FOCUSED** or **NOT FOCUSED** with a confidence score.
4. A timeline tracks your focus over time. Session history is saved for review.
## Models

| Model | What it uses | Best for |
|-------|-------------|----------|
| **Geometric** | Head pose angles + eye aspect ratio (EAR) | Fast, no ML needed |
| **XGBoost** | Trained classifier on head/eye features | Balanced accuracy/speed |
| **MLP** | Neural network on same features | Higher accuracy |
| **Hybrid** | Weighted MLP + Geometric ensemble | Best head-pose accuracy |
| **L2CS** | Deep gaze estimation (ResNet50) | Detects eye-only gaze shifts |
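The eye aspect ratio used by the Geometric model is the standard six-landmark EAR. The exact MediaPipe landmark indices this project uses are not shown on this page, so the helper below is an illustrative sketch, not the project's implementation:

```python
import numpy as np

def eye_aspect_ratio(eye: np.ndarray) -> float:
    """EAR for one eye given six (x, y) landmarks ordered
    [outer corner, upper-1, upper-2, inner corner, lower-2, lower-1].
    Open eyes sit roughly around 0.25-0.35; a blink drops EAR near 0.1."""
    v1 = np.linalg.norm(eye[1] - eye[5])  # first vertical distance
    v2 = np.linalg.norm(eye[2] - eye[4])  # second vertical distance
    h = np.linalg.norm(eye[0] - eye[3])   # horizontal distance
    return float((v1 + v2) / (2.0 * h))

# Synthetic wide-open eye: horizontal span 4, vertical spans 1
eye = np.array([[0, 0], [1, 0.5], [3, 0.5], [4, 0], [3, -0.5], [1, -0.5]], dtype=float)
print(eye_aspect_ratio(eye))
```

A drop in EAR below a blink threshold, sustained over several frames, is what eye-openness scoring typically keys on.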
## L2CS Gaze Tracking

L2CS-Net predicts where your eyes are looking, not just where your head is pointed. This catches the scenario where your head faces the screen but your eyes wander.
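The commit message mentions "cosine scoring" for the gaze signal. The actual scoring in ui/pipeline.py is not reproduced on this page; the sketch below shows one common form of the idea, assuming yaw/pitch in radians and a camera-facing screen normal:

```python
import numpy as np

def gaze_vector(yaw: float, pitch: float) -> np.ndarray:
    """Unit gaze direction from yaw/pitch in radians,
    with the camera looking along -z."""
    x = -np.cos(pitch) * np.sin(yaw)
    y = -np.sin(pitch)
    z = -np.cos(pitch) * np.cos(yaw)
    return np.array([x, y, z])

def cosine_focus_score(yaw: float, pitch: float) -> float:
    """Cosine similarity between the gaze direction and the screen
    normal (0, 0, -1): 1.0 means looking dead ahead, lower means
    the gaze is angled away from the screen."""
    return float(gaze_vector(yaw, pitch) @ np.array([0.0, 0.0, -1.0]))
```

Scores decay smoothly as the gaze turns away, which makes a threshold (like the veto described below) easy to tune.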
### Standalone mode

Select **L2CS** as the model - it handles everything.

### Boost mode

Select any other model, then click the **GAZE** toggle. L2CS runs alongside the base model:

- Base model handles head pose and eye openness (35% weight)
- L2CS handles gaze direction (65% weight)
- If L2CS detects gaze is clearly off-screen, it **vetoes** the base model regardless of score
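Condensed from the boost-mode logic this commit adds to main.py (weights 0.35/0.65, veto at 0.38, and a 0.52 focus threshold), the fusion rule is:

```python
BASE_W, L2CS_W = 0.35, 0.65   # base model vs L2CS weights
VETO = 0.38                   # L2CS score below this forces not-focused
FOCUS_T = 0.52                # fused score >= this counts as focused

def fuse_scores(base_score: float, l2cs_score: float) -> tuple[float, bool]:
    """Return (fused_score, is_focused) for boost mode."""
    if l2cs_score < VETO:     # gaze clearly off-screen: veto the base model
        return l2cs_score * 0.8, False
    fused = BASE_W * base_score + L2CS_W * l2cs_score
    return fused, fused >= FOCUS_T
```

Note the veto path ignores the base score entirely: a confident head-pose model cannot outvote an off-screen gaze.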
### Calibration

After enabling L2CS or Gaze Boost, click **Calibrate** while a session is running:

1. A fullscreen overlay shows 9 target dots (3x3 grid)
2. Look at each dot as the progress ring fills
3. The first dot (centre) sets your baseline gaze offset
4. After all 9 points, a polynomial model maps your gaze angles to screen coordinates
5. A cyan tracking dot appears on the video showing where you're looking
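Assuming the calibration model is a least-squares polynomial fit from (yaw, pitch) to screen coordinates (the exact feature set in models/gaze_calibration.py is not shown on this page), a minimal version of step 4 might look like:

```python
import numpy as np

def fit_gaze_map(yaw, pitch, sx, sy):
    """Least-squares map from gaze angles to screen coordinates using
    quadratic features [1, yaw, pitch, yaw*pitch, yaw^2, pitch^2].
    Returns one coefficient vector per screen axis."""
    yaw, pitch = np.asarray(yaw, float), np.asarray(pitch, float)
    A = np.column_stack([np.ones_like(yaw), yaw, pitch, yaw * pitch, yaw**2, pitch**2])
    cx, *_ = np.linalg.lstsq(A, np.asarray(sx, float), rcond=None)
    cy, *_ = np.linalg.lstsq(A, np.asarray(sy, float), rcond=None)
    return cx, cy

def predict(cx, cy, yaw, pitch):
    """Map a single (yaw, pitch) sample to (screen_x, screen_y)."""
    f = np.array([1.0, yaw, pitch, yaw * pitch, yaw**2, pitch**2])
    return float(f @ cx), float(f @ cy)
```

Nine grid points comfortably over-determine the six coefficients per axis, which is why the overlay uses a 3x3 grid; bias correction and IQR outlier filtering (from the commit message) would run on the samples before the fit.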
## Tech Stack

- **Backend**: FastAPI + WebSocket, Python 3.10
- **Frontend**: React + Vite
- **Face detection**: MediaPipe Face Landmarker (478 landmarks)
- **Gaze estimation**: L2CS-Net (ResNet50, Gaze360 weights)
- **ML models**: XGBoost, PyTorch MLP
- **Deployment**: Docker on Hugging Face Spaces
## Running Locally

```bash
# install Python deps
pip install -r requirements.txt

# install frontend deps and build
npm install && npm run build

# start the server
uvicorn main:app --port 8000
```

Open `http://localhost:8000` in your browser.
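The frontend talks to the server over the `/ws/video` WebSocket; per the handler in main.py, each frame goes up as a JSON message carrying a base64-encoded JPEG. A minimal sketch of building that message (send it with any WebSocket client; the field names come from the server code in this commit):

```python
import base64
import json

def build_frame_message(jpeg_bytes: bytes) -> str:
    """JSON frame message for the /ws/video endpoint. The server checks
    data["type"] == "frame" and base64-decodes data["image"]."""
    return json.dumps({
        "type": "frame",
        "image": base64.b64encode(jpeg_bytes).decode("ascii"),
    })
```

The server keeps only the latest frame per connection and drops older ones, so a client can send at capture rate without backpressure handling.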
## Project Structure

```
main.py                      # FastAPI app, WebSocket handler, API endpoints
ui/pipeline.py               # All focus detection pipelines (Geometric, MLP, XGBoost, Hybrid, L2CS)
models/
  face_mesh.py               # MediaPipe face landmark detector
  head_pose.py               # Head pose estimation from landmarks
  eye_scorer.py              # EAR/eye behaviour scoring
  gaze_calibration.py        # 9-point polynomial gaze calibration
  gaze_eye_fusion.py         # Fuses calibrated gaze with eye openness
  L2CS-Net/                  # In-tree L2CS-Net repo with Gaze360 weights
src/
  components/
    FocusPageLocal.jsx       # Main focus page (camera, controls, model selector)
    CalibrationOverlay.jsx   # Fullscreen calibration UI
  utils/
    VideoManagerLocal.js     # WebSocket client, frame capture, canvas rendering
Dockerfile                   # Docker build for HF Spaces
```
download_l2cs_weights.py ADDED

@@ -0,0 +1,37 @@

```python
#!/usr/bin/env python3
# Downloads L2CS-Net Gaze360 weights into checkpoints/

import os
import sys

CHECKPOINTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
DEST = os.path.join(CHECKPOINTS_DIR, "L2CSNet_gaze360.pkl")
GDRIVE_ID = "1dL2Jokb19_SBSHAhKHOxJsmYs5-GoyLo"


def main():
    if os.path.isfile(DEST):
        print(f"[OK] Weights already at {DEST}")
        return

    try:
        import gdown
    except ImportError:
        print("gdown not installed. Run: pip install gdown")
        sys.exit(1)

    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
    print(f"Downloading L2CS-Net weights to {DEST} ...")
    gdown.download(f"https://drive.google.com/uc?id={GDRIVE_ID}", DEST, quiet=False)

    if os.path.isfile(DEST):
        print(f"[OK] Downloaded ({os.path.getsize(DEST) / 1024 / 1024:.1f} MB)")
    else:
        print("[ERR] Download failed. Manual download:")
        print("  https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd")
        print(f"  Place L2CSNet_gaze360.pkl in {CHECKPOINTS_DIR}/")
        sys.exit(1)


if __name__ == "__main__":
    main()
```
main.py CHANGED

```diff
@@ -22,7 +22,10 @@ from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
 from av import VideoFrame
 
 from mediapipe.tasks.python.vision import FaceLandmarksConnections
-from ui.pipeline import
+from ui.pipeline import (
+    FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline,
+    L2CSPipeline, is_l2cs_weights_available,
+)
 from models.face_mesh import FaceMeshDetector
 
 # ================ FACE MESH DRAWING (server-side, for WebRTC) ================
@@ -164,6 +167,7 @@ app.add_middleware(
 db_path = "focus_guard.db"
 pcs = set()
 _cached_model_name = "mlp"  # in-memory cache, updated via /api/settings
+_l2cs_boost_enabled = False  # when True, L2CS runs alongside the base model
 
 async def _wait_for_ice_gathering(pc: RTCPeerConnection):
     if pc.iceGatheringState == "complete":
@@ -243,6 +247,7 @@ class SettingsUpdate(BaseModel):
     notification_threshold: Optional[int] = None
     frame_rate: Optional[int] = None
     model_name: Optional[str] = None
+    l2cs_boost: Optional[bool] = None
 
 class VideoTransformTrack(VideoStreamTrack):
     def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
@@ -270,6 +275,8 @@ class VideoTransformTrack(VideoStreamTrack):
             self.last_inference_time = now
 
             model_name = _cached_model_name
+            if model_name == "l2cs" and pipelines.get("l2cs") is None:
+                _ensure_l2cs()
             if model_name not in pipelines or pipelines.get(model_name) is None:
                 model_name = 'mlp'
             active_pipeline = pipelines.get(model_name)
@@ -455,6 +462,7 @@ pipelines = {
     "mlp": None,
     "hybrid": None,
     "xgboost": None,
+    "l2cs": None,
 }
 
 # Thread pool for CPU-bound inference so the event loop stays responsive.
@@ -464,14 +472,81 @@ _inference_executor = concurrent.futures.ThreadPoolExecutor(
 )
 # One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
 # multiple frames are processed in parallel by the thread pool.
-_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
+_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost", "l2cs")}
+
+_l2cs_load_lock = threading.Lock()
+_l2cs_error: str | None = None
+
+
+def _ensure_l2cs():
+    # lazy-load L2CS on first use, double-checked locking
+    global _l2cs_error
+    if pipelines["l2cs"] is not None:
+        return True
+    with _l2cs_load_lock:
+        if pipelines["l2cs"] is not None:
+            return True
+        if not is_l2cs_weights_available():
+            _l2cs_error = "Weights not found"
+            return False
+        try:
+            pipelines["l2cs"] = L2CSPipeline()
+            _l2cs_error = None
+            print("[OK] L2CSPipeline lazy-loaded")
+            return True
+        except Exception as e:
+            _l2cs_error = str(e)
+            print(f"[ERR] L2CS lazy-load failed: {e}")
+            return False
 
 
-def _process_frame_safe(pipeline, frame, model_name
-    """Run process_frame in executor with per-pipeline lock."""
+def _process_frame_safe(pipeline, frame, model_name):
     with _pipeline_locks[model_name]:
         return pipeline.process_frame(frame)
 
+
+_BOOST_BASE_W = 0.35
+_BOOST_L2CS_W = 0.65
+_BOOST_VETO = 0.38  # L2CS below this -> forced not-focused
+
+
+def _process_frame_with_l2cs_boost(base_pipeline, frame, base_model_name):
+    # run base model
+    with _pipeline_locks[base_model_name]:
+        base_out = base_pipeline.process_frame(frame)
+
+    l2cs_pipe = pipelines.get("l2cs")
+    if l2cs_pipe is None:
+        base_out["boost_active"] = False
+        return base_out
+
+    # run L2CS
+    with _pipeline_locks["l2cs"]:
+        l2cs_out = l2cs_pipe.process_frame(frame)
+
+    base_score = base_out.get("mlp_prob", base_out.get("raw_score", 0.0))
+    l2cs_score = l2cs_out.get("raw_score", 0.0)
+
+    # veto: gaze clearly off-screen overrides base model
+    if l2cs_score < _BOOST_VETO:
+        fused_score = l2cs_score * 0.8
+        is_focused = False
+    else:
+        fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score
+        is_focused = fused_score >= 0.52
+
+    base_out["raw_score"] = fused_score
+    base_out["is_focused"] = is_focused
+    base_out["boost_active"] = True
+    base_out["base_score"] = round(base_score, 3)
+    base_out["l2cs_score"] = round(l2cs_score, 3)
+
+    if l2cs_out.get("gaze_yaw") is not None:
+        base_out["gaze_yaw"] = l2cs_out["gaze_yaw"]
+        base_out["gaze_pitch"] = l2cs_out["gaze_pitch"]
+
+    return base_out
+
 @app.on_event("startup")
 async def startup_event():
     global pipelines, _cached_model_name
@@ -509,6 +584,11 @@ async def startup_event():
     except Exception as e:
         print(f"[ERR] Failed to load XGBoostPipeline: {e}")
 
+    if is_l2cs_weights_available():
+        print("[OK] L2CS weights found — pipeline will be lazy-loaded on first use")
+    else:
+        print("[WARN] L2CS weights not found — l2cs model unavailable")
+
 @app.on_event("shutdown")
 async def shutdown_event():
     _inference_executor.shutdown(wait=False)
@@ -579,14 +659,19 @@ async def webrtc_offer(offer: dict):
 
 @app.websocket("/ws/video")
 async def websocket_endpoint(websocket: WebSocket):
+    from models.gaze_calibration import GazeCalibration
+    from models.gaze_eye_fusion import GazeEyeFusion
+
     await websocket.accept()
     session_id = None
     frame_count = 0
     running = True
     event_buffer = _EventBuffer(flush_interval=2.0)
 
+    # Calibration state (per-connection)
+    _cal: dict = {"cal": None, "collecting": False, "fusion": None}
+
     # Latest frame slot — only the most recent frame is kept, older ones are dropped.
-    # Using a dict so nested functions can mutate without nonlocal issues.
     _slot = {"frame": None}
     _frame_ready = asyncio.Event()
@@ -617,7 +702,6 @@ async def websocket_endpoint(websocket: WebSocket):
             data = json.loads(text)
 
             if data["type"] == "frame":
-                # Legacy base64 path (fallback)
                 _slot["frame"] = base64.b64decode(data["image"])
                 _frame_ready.set()
@@ -636,6 +720,47 @@ async def websocket_endpoint(websocket: WebSocket):
                 if summary:
                     await websocket.send_json({"type": "session_ended", "summary": summary})
                 session_id = None
+
+            # ---- Calibration commands ----
+            elif data["type"] == "calibration_start":
+                loop = asyncio.get_event_loop()
+                await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+                _cal["cal"] = GazeCalibration()
+                _cal["collecting"] = True
+                _cal["fusion"] = None
+                cal = _cal["cal"]
+                await websocket.send_json({
+                    "type": "calibration_started",
+                    "num_points": cal.num_points,
+                    "target": list(cal.current_target),
+                    "index": cal.current_index,
+                })
+
+            elif data["type"] == "calibration_next":
+                cal = _cal.get("cal")
+                if cal is not None:
+                    more = cal.advance()
+                    if more:
+                        await websocket.send_json({
+                            "type": "calibration_point",
+                            "target": list(cal.current_target),
+                            "index": cal.current_index,
+                        })
+                    else:
+                        _cal["collecting"] = False
+                        ok = cal.fit()
+                        if ok:
+                            _cal["fusion"] = GazeEyeFusion(cal)
+                            await websocket.send_json({"type": "calibration_done", "success": True})
+                        else:
+                            await websocket.send_json({"type": "calibration_done", "success": False, "error": "Not enough samples"})
+
+            elif data["type"] == "calibration_cancel":
+                _cal["cal"] = None
+                _cal["collecting"] = False
+                _cal["fusion"] = None
+                await websocket.send_json({"type": "calibration_cancelled"})
+
     except WebSocketDisconnect:
         running = False
         _frame_ready.set()
@@ -654,7 +779,6 @@ async def websocket_endpoint(websocket: WebSocket):
             if not running:
                 return
 
-            # Grab latest frame and clear slot
             raw = _slot["frame"]
             _slot["frame"] = None
             if raw is None:
@@ -667,38 +791,87 @@ async def websocket_endpoint(websocket: WebSocket):
                 continue
             frame = cv2.resize(frame, (640, 480))
 
+            # During calibration collection, always use L2CS
+            collecting = _cal.get("collecting", False)
+            if collecting:
+                if pipelines.get("l2cs") is None:
+                    await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+                use_model = "l2cs" if pipelines.get("l2cs") is not None else _cached_model_name
+            else:
+                use_model = _cached_model_name
+
+            model_name = use_model
+            if model_name == "l2cs" and pipelines.get("l2cs") is None:
+                await loop.run_in_executor(_inference_executor, _ensure_l2cs)
             if model_name not in pipelines or pipelines.get(model_name) is None:
                 model_name = "mlp"
             active_pipeline = pipelines.get(model_name)
 
+            # L2CS boost: run L2CS alongside base model
+            use_boost = (
+                _l2cs_boost_enabled
+                and model_name != "l2cs"
+                and pipelines.get("l2cs") is not None
+                and not collecting
+            )
+
             landmarks_list = None
+            out = None
             if active_pipeline is not None:
+                if use_boost:
+                    out = await loop.run_in_executor(
+                        _inference_executor,
+                        _process_frame_with_l2cs_boost,
+                        active_pipeline,
+                        frame,
+                        model_name,
+                    )
+                else:
+                    out = await loop.run_in_executor(
+                        _inference_executor,
+                        _process_frame_safe,
+                        active_pipeline,
+                        frame,
+                        model_name,
+                    )
                 is_focused = out["is_focused"]
                 confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
 
                 lm = out.get("landmarks")
                 if lm is not None:
-                    # Send all 478 landmarks as flat array for tessellation drawing
                     landmarks_list = [
                         [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
                         for i in range(lm.shape[0])
                    ]
 
+                # Calibration sample collection (L2CS gaze angles)
+                if collecting and _cal.get("cal") is not None:
+                    pipe_yaw = out.get("gaze_yaw")
+                    pipe_pitch = out.get("gaze_pitch")
+                    if pipe_yaw is not None and pipe_pitch is not None:
+                        _cal["cal"].collect_sample(pipe_yaw, pipe_pitch)
+
+                # Gaze fusion (when L2CS active + calibration fitted)
+                fusion = _cal.get("fusion")
+                if (
+                    fusion is not None
+                    and model_name == "l2cs"
+                    and out.get("gaze_yaw") is not None
+                ):
+                    fuse = fusion.update(
+                        out["gaze_yaw"], out["gaze_pitch"], lm
+                    )
+                    is_focused = fuse["focused"]
+                    confidence = fuse["focus_score"]
+
                 if session_id:
+                    metadata = {
                         "s_face": out.get("s_face", 0.0),
                         "s_eye": out.get("s_eye", 0.0),
                         "mar": out.get("mar", 0.0),
                         "model": model_name,
-                    }
+                    }
+                    event_buffer.add(session_id, is_focused, confidence, metadata)
             else:
                 is_focused = False
                 confidence = 0.0
@@ -710,8 +883,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 "model": model_name,
                 "fc": frame_count,
             }
-            if
-            # Send detailed metrics for HUD
+            if out is not None:
                 if out.get("yaw") is not None:
                     resp["yaw"] = round(out["yaw"], 1)
                     resp["pitch"] = round(out["pitch"], 1)
@@ -720,6 +892,24 @@ async def websocket_endpoint(websocket: WebSocket):
                     resp["mar"] = round(out["mar"], 3)
                     resp["sf"] = round(out.get("s_face", 0), 3)
                     resp["se"] = round(out.get("s_eye", 0), 3)
+
+                # Gaze fusion fields (L2CS standalone or boost mode)
+                fusion = _cal.get("fusion")
+                has_gaze = out.get("gaze_yaw") is not None
+                if fusion is not None and has_gaze and (model_name == "l2cs" or use_boost):
+                    fuse = fusion.update(out["gaze_yaw"], out["gaze_pitch"], out.get("landmarks"))
+                    resp["gaze_x"] = fuse["gaze_x"]
+                    resp["gaze_y"] = fuse["gaze_y"]
+                    resp["on_screen"] = fuse["on_screen"]
+                    if model_name == "l2cs":
+                        resp["focused"] = fuse["focused"]
+                        resp["confidence"] = round(fuse["focus_score"], 3)
+
+                if out.get("boost_active"):
+                    resp["boost"] = True
+                    resp["base_score"] = out.get("base_score", 0)
+                    resp["l2cs_score"] = out.get("l2cs_score", 0)
+
             if landmarks_list is not None:
                 resp["lm"] = landmarks_list
             await websocket.send_json(resp)
@@ -852,8 +1042,9 @@ async def get_settings():
         db.row_factory = aiosqlite.Row
         cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
         row = await cursor.fetchone()
+        result = dict(row) if row else {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
+        result['l2cs_boost'] = _l2cs_boost_enabled
+        return result
 
 @app.put("/api/settings")
 async def update_settings(settings: SettingsUpdate):
@@ -878,12 +1069,28 @@ async def update_settings(settings: SettingsUpdate):
         if settings.frame_rate is not None:
             updates.append("frame_rate = ?")
             params.append(max(5, min(60, settings.frame_rate)))
-        if settings.model_name is not None and settings.model_name in pipelines
+        if settings.model_name is not None and settings.model_name in pipelines:
+            if settings.model_name == "l2cs":
+                loop = asyncio.get_event_loop()
+                loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+                if not loaded:
+                    raise HTTPException(status_code=400, detail=f"L2CS model unavailable: {_l2cs_error}")
+            elif pipelines[settings.model_name] is None:
+                raise HTTPException(status_code=400, detail=f"Model '{settings.model_name}' not loaded")
             updates.append("model_name = ?")
             params.append(settings.model_name)
             global _cached_model_name
             _cached_model_name = settings.model_name
 
+        if settings.l2cs_boost is not None:
+            global _l2cs_boost_enabled
+            if settings.l2cs_boost:
+                loop = asyncio.get_event_loop()
+                loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+                if not loaded:
+                    raise HTTPException(status_code=400, detail=f"L2CS boost unavailable: {_l2cs_error}")
+            _l2cs_boost_enabled = settings.l2cs_boost
+
         if updates:
             query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
             await db.execute(query, params)
@@ -919,15 +1126,55 @@ async def get_stats_summary():
 
 @app.get("/api/models")
 async def get_available_models():
-    """Return
+    """Return model names, statuses, and which is currently active."""
+    statuses = {}
+    errors = {}
+    available = []
+    for name, p in pipelines.items():
```
|
| 1134 |
+
if name == "l2cs":
|
| 1135 |
+
if p is not None:
|
| 1136 |
+
statuses[name] = "ready"
|
| 1137 |
+
available.append(name)
|
| 1138 |
+
elif is_l2cs_weights_available():
|
| 1139 |
+
statuses[name] = "lazy"
|
| 1140 |
+
available.append(name)
|
| 1141 |
+
elif _l2cs_error:
|
| 1142 |
+
statuses[name] = "error"
|
| 1143 |
+
errors[name] = _l2cs_error
|
| 1144 |
+
else:
|
| 1145 |
+
statuses[name] = "unavailable"
|
| 1146 |
+
elif p is not None:
|
| 1147 |
+
statuses[name] = "ready"
|
| 1148 |
+
available.append(name)
|
| 1149 |
+
else:
|
| 1150 |
+
statuses[name] = "unavailable"
|
| 1151 |
async with aiosqlite.connect(db_path) as db:
|
| 1152 |
cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
|
| 1153 |
row = await cursor.fetchone()
|
| 1154 |
current = row[0] if row else "mlp"
|
| 1155 |
if current not in available and available:
|
| 1156 |
current = available[0]
|
| 1157 |
+
l2cs_boost_available = (
|
| 1158 |
+
statuses.get("l2cs") in ("ready", "lazy") and current != "l2cs"
|
| 1159 |
+
)
|
| 1160 |
+
return {
|
| 1161 |
+
"available": available,
|
| 1162 |
+
"current": current,
|
| 1163 |
+
"statuses": statuses,
|
| 1164 |
+
"errors": errors,
|
| 1165 |
+
"l2cs_boost": _l2cs_boost_enabled,
|
| 1166 |
+
"l2cs_boost_available": l2cs_boost_available,
|
| 1167 |
+
}
|
| 1168 |
+
|
| 1169 |
+
@app.get("/api/l2cs/status")
|
| 1170 |
+
async def l2cs_status():
|
| 1171 |
+
"""L2CS-specific status: weights available, loaded, and calibration info."""
|
| 1172 |
+
loaded = pipelines.get("l2cs") is not None
|
| 1173 |
+
return {
|
| 1174 |
+
"weights_available": is_l2cs_weights_available(),
|
| 1175 |
+
"loaded": loaded,
|
| 1176 |
+
"error": _l2cs_error,
|
| 1177 |
+
}
|
| 1178 |
|
| 1179 |
@app.get("/api/mesh-topology")
|
| 1180 |
async def get_mesh_topology():
|
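For orientation, the new `/api/models` handler above returns a payload with `available`, `current`, `statuses`, `errors`, and the two boost flags. A minimal sketch of how a client might pick a model from that payload (the selection logic here is illustrative, not part of the commit):

```python
def pick_model(payload):
    """Pick a usable model from the /api/models response: prefer the
    server's current model, otherwise fall back to the first available."""
    available = payload.get("available", [])
    current = payload.get("current")
    if current in available:
        return current
    return available[0] if available else None

# Field names mirror the endpoint's return value in the diff above.
payload = {
    "available": ["mlp", "l2cs"],
    "current": "l2cs",
    "statuses": {"mlp": "ready", "l2cs": "lazy"},
    "errors": {},
    "l2cs_boost": False,
    "l2cs_boost_available": True,
}
print(pick_model(payload))  # -> l2cs
```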
models/L2CS-Net/.gitignore
ADDED
@@ -0,0 +1,140 @@
# Ignore the test data - sensitive
datasets/
evaluation/
output/

# Ignore debugging configurations
/.vscode

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Ignore other files
my.secrets
models/L2CS-Net/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Ahmed Abdelrahman

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
models/L2CS-Net/README.md
ADDED
@@ -0,0 +1,148 @@
<p align="center">
<img src="https://github.com/Ahmednull/Storage/blob/main/gaze.gif" alt="animated" />
</p>

___

# L2CS-Net

The official PyTorch implementation of L2CS-Net for gaze estimation and tracking.

## Installation
<img src="https://img.shields.io/badge/python%20-%2314354C.svg?&style=for-the-badge&logo=python&logoColor=white"/> <img src="https://img.shields.io/badge/PyTorch%20-%23EE4C2C.svg?&style=for-the-badge&logo=PyTorch&logoColor=white" />

Install the package with:

```
pip install git+https://github.com/Ahmednull/L2CS-Net.git@main
```

Or clone the repo and install with:

```
pip install [-e] .
```

You should now be able to import the package:

```
$ python
>>> import l2cs
```

## Usage

Detect a face and predict gaze from the webcam (imports and the `CWD`/`cam` definitions were missing from the original snippet and are filled in here):

```python
import pathlib

import cv2
import torch

from l2cs import Pipeline, render

CWD = pathlib.Path.cwd()
cam = 0  # camera device id

gaze_pipeline = Pipeline(
    weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
    arch='ResNet50',
    device=torch.device('cpu')  # or 'gpu'
)

cap = cv2.VideoCapture(cam)
_, frame = cap.read()

# Process frame and visualize
results = gaze_pipeline.step(frame)
frame = render(frame, results)
```

## Demo
* Download the pre-trained models from [here](https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd?usp=sharing) and store them in *models/*.
* Run:
```
python demo.py \
 --snapshot models/L2CSNet_gaze360.pkl \
 --gpu 0 \
 --cam 0 \
```
The demo will run using the *L2CSNet_gaze360.pkl* pretrained model.

## Community Contributions

- [Gaze Detection and Eye Tracking: A How-To Guide](https://blog.roboflow.com/gaze-direction-position/): use L2CS-Net through an HTTP interface with the open-source Roboflow Inference project.

## MPIIGaze
We provide code to train and test on the MPIIGaze dataset with leave-one-person-out evaluation.

### Prepare datasets
* Download the **MPIIFaceGaze dataset** from [here](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/its-written-all-over-your-face-full-face-appearance-based-gaze-estimation).
* Apply the data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).
* Store the dataset in *datasets/MPIIFaceGaze*.

### Train
```
python train.py \
 --dataset mpiigaze \
 --snapshot output/snapshots \
 --gpu 0 \
 --num_epochs 50 \
 --batch_size 16 \
 --lr 0.00001 \
 --alpha 1 \
```
This performs leave-one-person-out training automatically and stores the models in *output/snapshots*.

### Test
```
python test.py \
 --dataset mpiigaze \
 --snapshot output/snapshots/snapshot_folder \
 --evalpath evaluation/L2CS-mpiigaze \
 --gpu 0 \
```
This performs leave-one-person-out testing automatically and stores the results in *evaluation/L2CS-mpiigaze*.

To get the average leave-one-person-out accuracy, use:
```
python leave_one_out_eval.py \
 --evalpath evaluation/L2CS-mpiigaze \
 --respath evaluation/L2CS-mpiigaze \
```
This takes the evaluation path and writes the leave-one-out gaze accuracy to *evaluation/L2CS-mpiigaze*.

## Gaze360
We provide code to train and test on the Gaze360 dataset with a train-val-test split.

### Prepare datasets
* Download the **Gaze360 dataset** from [here](http://gaze360.csail.mit.edu/download.php).
* Apply the data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).
* Store the dataset in *datasets/Gaze360*.

### Train
```
python train.py \
 --dataset gaze360 \
 --snapshot output/snapshots \
 --gpu 0 \
 --num_epochs 50 \
 --batch_size 16 \
 --lr 0.00001 \
 --alpha 1 \
```
This performs training and stores the models in *output/snapshots*.

### Test
```
python test.py \
 --dataset gaze360 \
 --snapshot output/snapshots/snapshot_folder \
 --evalpath evaluation/L2CS-gaze360 \
 --gpu 0 \
```
This performs testing on snapshot_folder and stores the results in *evaluation/L2CS-gaze360*.
models/L2CS-Net/demo.py
ADDED
@@ -0,0 +1,87 @@
import argparse
import pathlib
import numpy as np
import cv2
import time

import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision

from PIL import Image, ImageOps

from face_detection import RetinaFace

from l2cs import select_device, draw_gaze, getArch, Pipeline, render

CWD = pathlib.Path.cwd()

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(
        description='Gaze evaluation using a model pretrained with L2CS-Net on Gaze360.')
    parser.add_argument(
        '--device', dest='device', help='Device to run model: cpu or gpu:0',
        default="cpu", type=str)
    parser.add_argument(
        '--snapshot', dest='snapshot', help='Path of model snapshot.',
        default='output/snapshots/L2CS-gaze360-_loader-180-4/_epoch_55.pkl', type=str)
    parser.add_argument(
        '--cam', dest='cam_id', help='Camera device id to use [0]',
        default=0, type=int)
    parser.add_argument(
        '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152',
        default='ResNet50', type=str)

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    arch = args.arch
    cam = args.cam_id
    # snapshot_path = args.snapshot

    gaze_pipeline = Pipeline(
        weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
        arch='ResNet50',
        device=select_device(args.device, batch_size=1)
    )

    cap = cv2.VideoCapture(cam)

    # Check if the webcam is opened correctly
    if not cap.isOpened():
        raise IOError("Cannot open webcam")

    with torch.no_grad():
        while True:

            # Get frame
            success, frame = cap.read()
            start_fps = time.time()

            if not success:
                print("Failed to obtain frame")
                time.sleep(0.1)
                continue  # skip this iteration rather than process an empty frame

            # Process frame
            results = gaze_pipeline.step(frame)

            # Visualize output
            frame = render(frame, results)

            myFPS = 1.0 / (time.time() - start_fps)
            cv2.putText(frame, 'FPS: {:.1f}'.format(myFPS), (10, 20), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 255, 0), 1, cv2.LINE_AA)

            cv2.imshow("Demo", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            success, frame = cap.read()
models/L2CS-Net/l2cs/__init__.py
ADDED
@@ -0,0 +1,21 @@
from .utils import select_device, natural_keys, gazeto3d, angular, getArch
from .vis import draw_gaze, render
from .model import L2CS
from .pipeline import Pipeline
from .datasets import Gaze360, Mpiigaze

__all__ = [
    # Classes
    'L2CS',
    'Pipeline',
    'Gaze360',
    'Mpiigaze',
    # Utils
    'render',
    'select_device',
    'draw_gaze',
    'natural_keys',
    'gazeto3d',
    'angular',
    'getArch'
]
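The `gazeto3d` and `angular` helpers exported above convert a gaze angle pair into a 3D direction vector and measure the angle between two such vectors. A hedged sketch of the standard conversion (the axis conventions here are an assumption, not copied from `l2cs/utils.py`):

```python
import numpy as np

def gaze_to_3d(yaw, pitch):
    """Map yaw/pitch (radians) to a unit gaze vector.
    Axis conventions are assumed for illustration."""
    return np.array([
        -np.cos(pitch) * np.sin(yaw),
        -np.sin(pitch),
        -np.cos(pitch) * np.cos(yaw),
    ])

def angular_error(a, b):
    """Angle in degrees between two gaze vectors."""
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

v = gaze_to_3d(0.0, 0.0)        # looking straight ahead
w = gaze_to_3d(np.pi / 2, 0.0)  # 90 degrees of yaw
print(round(angular_error(v, w), 6))  # -> 90.0
```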
models/L2CS-Net/l2cs/datasets.py
ADDED
@@ -0,0 +1,157 @@
import os
import numpy as np
import cv2


import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image, ImageFilter


class Gaze360(Dataset):
    def __init__(self, path, root, transform, angle, binwidth, train=True):
        self.transform = transform
        self.root = root
        self.orig_list_len = 0
        self.angle = angle
        if train == False:
            angle = 90
        self.binwidth = binwidth
        self.lines = []
        if isinstance(path, list):
            for i in path:
                with open(i) as f:
                    print("here")
                    line = f.readlines()
                    line.pop(0)
                    self.lines.extend(line)
        else:
            with open(path) as f:
                lines = f.readlines()
                lines.pop(0)
                self.orig_list_len = len(lines)
                for line in lines:
                    gaze2d = line.strip().split(" ")[5]
                    label = np.array(gaze2d.split(",")).astype("float")
                    if abs((label[0]*180/np.pi)) <= angle and abs((label[1]*180/np.pi)) <= angle:
                        self.lines.append(line)

        print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len-len(self.lines), angle))

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        line = self.lines[idx]
        line = line.strip().split(" ")

        face = line[0]
        lefteye = line[1]
        righteye = line[2]
        name = line[3]
        gaze2d = line[5]
        label = np.array(gaze2d.split(",")).astype("float")
        label = torch.from_numpy(label).type(torch.FloatTensor)

        pitch = label[0] * 180 / np.pi
        yaw = label[1] * 180 / np.pi

        img = Image.open(os.path.join(self.root, face))

        # fimg = cv2.imread(os.path.join(self.root, face))
        # fimg = cv2.resize(fimg, (448, 448))/255.0
        # fimg = fimg.transpose(2, 0, 1)
        # img=torch.from_numpy(fimg).type(torch.FloatTensor)

        if self.transform:
            img = self.transform(img)

        # Bin values
        bins = np.array(range(-1*self.angle, self.angle, self.binwidth))
        binned_pose = np.digitize([pitch, yaw], bins) - 1

        labels = binned_pose
        cont_labels = torch.FloatTensor([pitch, yaw])

        return img, labels, cont_labels, name

class Mpiigaze(Dataset):
    def __init__(self, pathorg, root, transform, train, angle, fold=0):
        self.transform = transform
        self.root = root
        self.orig_list_len = 0
        self.lines = []
        path = pathorg.copy()
        if train == True:
            path.pop(fold)
        else:
            path = path[fold]
        if isinstance(path, list):
            for i in path:
                with open(i) as f:
                    lines = f.readlines()
                    lines.pop(0)
                    self.orig_list_len += len(lines)
                    for line in lines:
                        gaze2d = line.strip().split(" ")[7]
                        label = np.array(gaze2d.split(",")).astype("float")
                        if abs((label[0]*180/np.pi)) <= angle and abs((label[1]*180/np.pi)) <= angle:
                            self.lines.append(line)
        else:
            with open(path) as f:
                lines = f.readlines()
                lines.pop(0)
                self.orig_list_len += len(lines)
                for line in lines:
                    gaze2d = line.strip().split(" ")[7]
                    label = np.array(gaze2d.split(",")).astype("float")
                    if abs((label[0]*180/np.pi)) <= 42 and abs((label[1]*180/np.pi)) <= 42:
                        self.lines.append(line)

        print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len-len(self.lines), angle))

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        line = self.lines[idx]
        line = line.strip().split(" ")

        name = line[3]
        gaze2d = line[7]
        head2d = line[8]
        lefteye = line[1]
        righteye = line[2]
        face = line[0]

        label = np.array(gaze2d.split(",")).astype("float")
        label = torch.from_numpy(label).type(torch.FloatTensor)

        pitch = label[0] * 180 / np.pi
        yaw = label[1] * 180 / np.pi

        img = Image.open(os.path.join(self.root, face))

        # fimg = cv2.imread(os.path.join(self.root, face))
        # fimg = cv2.resize(fimg, (448, 448))/255.0
        # fimg = fimg.transpose(2, 0, 1)
        # img=torch.from_numpy(fimg).type(torch.FloatTensor)

        if self.transform:
            img = self.transform(img)

        # Bin values
        bins = np.array(range(-42, 42, 3))
        binned_pose = np.digitize([pitch, yaw], bins) - 1

        labels = binned_pose
        cont_labels = torch.FloatTensor([pitch, yaw])

        return img, labels, cont_labels, name
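Both dataset classes above discretize the continuous pitch/yaw labels (degrees) into classification bins with `np.digitize`. The same logic in isolation, using the constants from `Mpiigaze.__getitem__` above (±42° range, 3° bins):

```python
import numpy as np

def bin_angles(pitch, yaw, angle=42, binwidth=3):
    """Discretize gaze angles (degrees) into bin indices, mirroring the
    np.digitize call in the dataset classes above."""
    bins = np.array(range(-angle, angle, binwidth))  # -42, -39, ..., 39
    return np.digitize([pitch, yaw], bins) - 1

print(bin_angles(0.0, 10.0))  # -> [14 17]
```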
models/L2CS-Net/l2cs/model.py
ADDED
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import math
import torch.nn.functional as F


class L2CS(nn.Module):
    def __init__(self, block, layers, num_bins):
        self.inplanes = 64
        super(L2CS, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc_yaw_gaze = nn.Linear(512 * block.expansion, num_bins)
        self.fc_pitch_gaze = nn.Linear(512 * block.expansion, num_bins)

        # Vestigial layer from previous experiments
        self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        # gaze
        pre_yaw_gaze = self.fc_yaw_gaze(x)
        pre_pitch_gaze = self.fc_pitch_gaze(x)
        return pre_yaw_gaze, pre_pitch_gaze
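The model above emits `num_bins` logits each for yaw and pitch rather than a single regressed angle; downstream code decodes them with a softmax expectation over bin indices (the pipeline below builds an `idx_tensor` of 90 entries for exactly this). A numpy sketch of that decode, with the 90-bin / 4° / ±180° constants assumed from the Gaze360 configuration rather than taken from this file:

```python
import numpy as np

def decode_angle(logits, num_bins=90, binwidth=4.0, offset=180.0):
    """Softmax-expectation decode of binned gaze logits into degrees.
    The constants (90 bins of 4 degrees spanning +/-180) are assumptions."""
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())  # numerically stable softmax
    probs /= probs.sum()
    expected_bin = float((probs * np.arange(num_bins)).sum())
    return expected_bin * binwidth - offset

# A distribution peaked at bin 45 decodes to roughly 45*4 - 180 = 0 degrees.
logits = np.full(90, -10.0)
logits[45] = 10.0
print(abs(decode_angle(logits)) < 1e-4)  # -> True
```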
models/L2CS-Net/l2cs/pipeline.py
ADDED
@@ -0,0 +1,133 @@
import pathlib
from typing import Union

import cv2
import numpy as np
import torch
import torch.nn as nn
from dataclasses import dataclass
from face_detection import RetinaFace

from .utils import prep_input_numpy, getArch
from .results import GazeResultContainer


class Pipeline:

    def __init__(
        self,
        weights: pathlib.Path,
        arch: str,
        device: str = 'cpu',
        include_detector: bool = True,
        confidence_threshold: float = 0.5
    ):

        # Save input parameters
        self.weights = weights
        self.include_detector = include_detector
        self.device = device
        self.confidence_threshold = confidence_threshold

        # Create L2CS model
        self.model = getArch(arch, 90)
        self.model.load_state_dict(torch.load(self.weights, map_location=device))
        self.model.to(self.device)
        self.model.eval()

        # Create RetinaFace if requested
        if self.include_detector:

            if device.type == 'cpu':
                self.detector = RetinaFace()
            else:
                self.detector = RetinaFace(gpu_id=device.index)

            self.softmax = nn.Softmax(dim=1)
            self.idx_tensor = [idx for idx in range(90)]
            self.idx_tensor = torch.FloatTensor(self.idx_tensor).to(self.device)

    def step(self, frame: np.ndarray) -> GazeResultContainer:

        # Creating containers
        face_imgs = []
        bboxes = []
        landmarks = []
        scores = []

        if self.include_detector:
            faces = self.detector(frame)

            if faces is not None:
                for box, landmark, score in faces:

                    # Apply threshold
                    if score < self.confidence_threshold:
                        continue

                    # Extract safe min and max of x,y
                    x_min = int(box[0])
                    if x_min < 0:
                        x_min = 0
                    y_min = int(box[1])
                    if y_min < 0:
                        y_min = 0
                    x_max = int(box[2])
                    y_max = int(box[3])

                    # Crop image
                    img = frame[y_min:y_max, x_min:x_max]
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = cv2.resize(img, (224, 224))
                    face_imgs.append(img)

                    # Save data
                    bboxes.append(box)
                    landmarks.append(landmark)
                    scores.append(score)
|
| 88 |
+
|
| 89 |
+
# Predict gaze
|
| 90 |
+
pitch, yaw = self.predict_gaze(np.stack(face_imgs))
|
| 91 |
+
|
| 92 |
+
else:
|
| 93 |
+
|
| 94 |
+
pitch = np.empty((0,1))
|
| 95 |
+
yaw = np.empty((0,1))
|
| 96 |
+
|
| 97 |
+
else:
|
| 98 |
+
pitch, yaw = self.predict_gaze(frame)
|
| 99 |
+
|
| 100 |
+
# Save data
|
| 101 |
+
results = GazeResultContainer(
|
| 102 |
+
pitch=pitch,
|
| 103 |
+
yaw=yaw,
|
| 104 |
+
bboxes=np.stack(bboxes),
|
| 105 |
+
landmarks=np.stack(landmarks),
|
| 106 |
+
scores=np.stack(scores)
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return results
|
| 110 |
+
|
| 111 |
+
def predict_gaze(self, frame: Union[np.ndarray, torch.Tensor]):
|
| 112 |
+
|
| 113 |
+
# Prepare input
|
| 114 |
+
if isinstance(frame, np.ndarray):
|
| 115 |
+
img = prep_input_numpy(frame, self.device)
|
| 116 |
+
elif isinstance(frame, torch.Tensor):
|
| 117 |
+
img = frame
|
| 118 |
+
else:
|
| 119 |
+
raise RuntimeError("Invalid dtype for input")
|
| 120 |
+
|
| 121 |
+
# Predict
|
| 122 |
+
gaze_pitch, gaze_yaw = self.model(img)
|
| 123 |
+
pitch_predicted = self.softmax(gaze_pitch)
|
| 124 |
+
yaw_predicted = self.softmax(gaze_yaw)
|
| 125 |
+
|
| 126 |
+
# Get continuous predictions in degrees.
|
| 127 |
+
pitch_predicted = torch.sum(pitch_predicted.data * self.idx_tensor, dim=1) * 4 - 180
|
| 128 |
+
yaw_predicted = torch.sum(yaw_predicted.data * self.idx_tensor, dim=1) * 4 - 180
|
| 129 |
+
|
| 130 |
+
pitch_predicted= pitch_predicted.cpu().detach().numpy()* np.pi/180.0
|
| 131 |
+
yaw_predicted= yaw_predicted.cpu().detach().numpy()* np.pi/180.0
|
| 132 |
+
|
| 133 |
+
return pitch_predicted, yaw_predicted
|
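The continuous-angle decoding in `predict_gaze` (softmax over 90 bins, expected bin index mapped to degrees via `* 4 - 180`) can be sketched standalone. `decode_bins` is a hypothetical helper name, not part of the l2cs package:

```python
import numpy as np

def decode_bins(logits, n_bins=90, bin_width=4.0, offset=-180.0):
    """Decode binned gaze logits to a continuous angle in degrees.

    Mirrors Pipeline.predict_gaze: softmax over the bins, then the
    expected bin index mapped linearly to degrees.
    """
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())   # numerically stable softmax
    probs /= probs.sum()
    return float(np.sum(probs * np.arange(n_bins)) * bin_width + offset)

# A near-one-hot spike at bin 45 decodes to 45 * 4 - 180 = 0 degrees.
angle = decode_bins([0.0] * 45 + [50.0] + [0.0] * 44)
```

The same mapping with `n_bins=28, bin_width=3.0, offset=-42.0` reproduces the MPIIGaze decoding used in test.py.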
models/L2CS-Net/l2cs/results.py
ADDED
|
@@ -0,0 +1,11 @@
from dataclasses import dataclass
import numpy as np


@dataclass
class GazeResultContainer:

    pitch: np.ndarray
    yaw: np.ndarray
    bboxes: np.ndarray
    landmarks: np.ndarray
    scores: np.ndarray
models/L2CS-Net/l2cs/utils.py
ADDED
|
@@ -0,0 +1,145 @@
import sys
import os
import math
from math import cos, sin
from pathlib import Path
import subprocess
import re

import numpy as np
import torch
import torch.nn as nn
import scipy.io as sio
import cv2
import torchvision
from torchvision import transforms

from .model import L2CS

transformations = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(448),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


def atoi(text):
    return int(text) if text.isdigit() else text


def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    return [atoi(c) for c in re.split(r'(\d+)', text)]


def prep_input_numpy(img: np.ndarray, device: str):
    """Prepare a numpy array as input to L2CS-Net."""
    if len(img.shape) == 4:
        img = torch.stack([transformations(im) for im in img])
    else:
        img = transformations(img)

    img = img.to(device)

    if len(img.shape) == 3:
        img = img.unsqueeze(0)

    return img


def gazeto3d(gaze):
    gaze_gt = np.zeros([3])
    gaze_gt[0] = -np.cos(gaze[1]) * np.sin(gaze[0])
    gaze_gt[1] = -np.sin(gaze[1])
    gaze_gt[2] = -np.cos(gaze[1]) * np.cos(gaze[0])
    return gaze_gt


def angular(gaze, label):
    total = np.sum(gaze * label)
    return np.arccos(min(total / (np.linalg.norm(gaze) * np.linalg.norm(label)), 0.9999999)) * 180 / np.pi


def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    s = f'YOLOv3 🚀 {git_describe()} torch {torch.__version__} '  # string
    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        # assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        devices = device.split(',') if device else range(torch.cuda.device_count())  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    return torch.device('cuda:0' if cuda else 'cpu')


def spherical2cartesial(x):
    output = torch.zeros(x.size(0), 3)
    output[:, 2] = -torch.cos(x[:, 1]) * torch.cos(x[:, 0])
    output[:, 0] = torch.cos(x[:, 1]) * torch.sin(x[:, 0])
    output[:, 1] = torch.sin(x[:, 1])
    return output


def compute_angular_error(input, target):
    input = spherical2cartesial(input)
    target = spherical2cartesial(target)

    input = input.view(-1, 3, 1)
    target = target.view(-1, 1, 3)
    output_dot = torch.bmm(target, input)
    output_dot = output_dot.view(-1)
    output_dot = torch.acos(output_dot)
    output_dot = output_dot.data
    output_dot = 180 * torch.mean(output_dot) / math.pi
    return output_dot


def softmax_temperature(tensor, temperature):
    result = torch.exp(tensor / temperature)
    result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result))
    return result


def git_describe(path=Path(__file__).parent):  # path must be a directory
    # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    s = f'git -C {path} describe --tags --long --always'
    try:
        return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
    except subprocess.CalledProcessError:
        return ''  # not a git repository


def getArch(arch, bins):
    # Base network structure
    if arch == 'ResNet18':
        model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
    elif arch == 'ResNet34':
        model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
    elif arch == 'ResNet101':
        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
    elif arch == 'ResNet152':
        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], bins)
    else:
        if arch != 'ResNet50':
            print('Invalid value for architecture is passed! '
                  'The default value of ResNet50 will be used instead!')
        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
    return model
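`gazeto3d` and `angular` above convert a (pitch, yaw) pair to a unit direction vector and measure the angle between two such vectors. A minimal self-contained sketch of the same convention (the helper names are illustrative, not part of the package):

```python
import numpy as np

def gaze_to_vec(pitch, yaw):
    # Same convention as utils.gazeto3d: angles in radians,
    # (0, 0) looks along -z (straight at the camera).
    return np.array([
        -np.cos(yaw) * np.sin(pitch),
        -np.sin(yaw),
        -np.cos(yaw) * np.cos(pitch),
    ])

def angular_deg(a, b):
    # Angle between two gaze vectors in degrees, clamped like angular().
    cosang = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return np.degrees(np.arccos(min(cosang, 0.9999999)))

straight = gaze_to_vec(0.0, 0.0)        # ~[0, 0, -1]
right = gaze_to_vec(np.pi / 2, 0.0)     # pitch rotated 90 degrees -> ~[-1, 0, 0]
```

Note the 0.9999999 clamp: it keeps `arccos` finite for identical vectors, at the cost of a small floor (~0.03 degrees) on the reported error.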
models/L2CS-Net/l2cs/vis.py
ADDED
|
@@ -0,0 +1,64 @@
import cv2
import numpy as np
from .results import GazeResultContainer


def draw_gaze(a, b, c, d, image_in, pitchyaw, thickness=2, color=(255, 255, 0), scale=2.0):
    """Draw gaze angle on given image with given eye positions."""
    image_out = image_in
    (h, w) = image_in.shape[:2]
    length = c
    pos = (int(a + c / 2.0), int(b + d / 2.0))
    if len(image_out.shape) == 2 or image_out.shape[2] == 1:
        image_out = cv2.cvtColor(image_out, cv2.COLOR_GRAY2BGR)
    dx = -length * np.sin(pitchyaw[0]) * np.cos(pitchyaw[1])
    dy = -length * np.sin(pitchyaw[1])
    cv2.arrowedLine(image_out, tuple(np.round(pos).astype(np.int32)),
                    tuple(np.round([pos[0] + dx, pos[1] + dy]).astype(int)), color,
                    thickness, cv2.LINE_AA, tipLength=0.18)
    return image_out


def draw_bbox(frame: np.ndarray, bbox: np.ndarray):

    x_min = max(int(bbox[0]), 0)
    y_min = max(int(bbox[1]), 0)
    x_max = int(bbox[2])
    y_max = int(bbox[3])

    cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 1)

    return frame


def render(frame: np.ndarray, results: GazeResultContainer):

    # Draw bounding boxes
    for bbox in results.bboxes:
        frame = draw_bbox(frame, bbox)

    # Draw gaze
    for i in range(results.pitch.shape[0]):

        bbox = results.bboxes[i]
        pitch = results.pitch[i]
        yaw = results.yaw[i]

        # Extract safe min and max of x,y
        x_min = max(int(bbox[0]), 0)
        y_min = max(int(bbox[1]), 0)
        x_max = int(bbox[2])
        y_max = int(bbox[3])

        # Compute sizes
        bbox_width = x_max - x_min
        bbox_height = y_max - y_min

        draw_gaze(x_min, y_min, bbox_width, bbox_height, frame, (pitch, yaw), color=(0, 0, 255))

    return frame
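`draw_gaze` projects the (pitch, yaw) direction onto the image plane before drawing the arrow. The projection on its own, as a small sketch (`gaze_arrow` is an illustrative name):

```python
import numpy as np

def gaze_arrow(pitch, yaw, length=100.0):
    """2-D image-plane offset for a gaze arrow, mirroring draw_gaze."""
    dx = -length * np.sin(pitch) * np.cos(yaw)
    dy = -length * np.sin(yaw)
    return dx, dy

# Looking straight at the camera produces no offset; a 90-degree pitch
# moves the arrow tip a full arrow-length to the left.
dx0, dy0 = gaze_arrow(0.0, 0.0)
dx1, dy1 = gaze_arrow(np.pi / 2, 0.0)
```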
models/L2CS-Net/leave_one_out_eval.py
ADDED
|
@@ -0,0 +1,54 @@
import os
import argparse


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(
        description='gaze estimation using binned loss function.')
    parser.add_argument(
        '--evalpath', dest='evalpath', help='path for evaluating gaze test.',
        default="evaluation/L2CS-gaze360-_standard-10", type=str)
    parser.add_argument(
        '--respath', dest='respath', help='path for saving result.',
        default="evaluation/L2CS-gaze360-_standard-10", type=str)
    return parser.parse_args()


if __name__ == '__main__':

    args = parse_args()
    evalpath = args.evalpath
    respath = args.respath
    if not os.path.exists(respath):
        os.makedirs(respath)
    with open(os.path.join(respath, "avg.log"), 'w') as outfile:
        outfile.write("Average equal\n")

        min_err = 10.0
        best_epoch = 0
        dirlist = os.listdir(evalpath)
        dirlist.sort()
        for j in range(50):
            avg = 0.0
            h = j + 3
            for i in dirlist:
                with open(evalpath + "/" + i + "/mpiigaze_binned.log") as myfile:
                    line = list(myfile)[h]
                avg += float(line.split("MAE:", 1)[1])

            avg = avg / 15.0
            if avg < min_err:
                min_err = avg
                best_epoch = j + 1
            outfile.write("epoch" + str(j + 1) + "= " + str(avg) + "\n")

        outfile.write("min angular error equal= " + str(min_err) + " at epoch= " + str(best_epoch) + "\n")
    print(min_err)
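The per-fold aggregation above hinges on extracting the MAE from log lines written by test.py (`f"...MAE:{avg_error/total}"`); the extraction reduces to a one-line split. A sketch with a made-up sample line:

```python
def parse_mae(line: str) -> float:
    """Parse the MAE value from a '...MAE:<value>' log line."""
    return float(line.split("MAE:", 1)[1])

mae = parse_mae("[epoch_20---gaze360] Total Num:2000,MAE:10.41\n")
```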
models/L2CS-Net/models/L2CSNet_gaze360.pkl
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8a7f3480d868dd48261e1d59f915b0ef0bb33ea12ea00938fb2168f212080665
size 95849977
models/L2CS-Net/models/README.md
ADDED
|
@@ -0,0 +1 @@
# Path to pre-trained models
models/L2CS-Net/pyproject.toml
ADDED
|
@@ -0,0 +1,44 @@
[project]
name = "l2cs"
version = "0.0.1"
description = "The official PyTorch implementation of L2CS-Net for gaze estimation and tracking"
authors = [
    {name = "Ahmed Abdelrahman"},
    {name = "Thorsten Hempel"}
]
license = {file = "LICENSE"}
readme = "README.md"
requires-python = ">3.6"

keywords = ["gaze", "estimation", "eye-tracking", "deep-learning", "pytorch"]

classifiers = [
    "Programming Language :: Python :: 3"
]

dependencies = [
    'matplotlib>=3.3.4',
    'numpy>=1.19.5',
    'opencv-python>=4.5.5',
    'pandas>=1.1.5',
    'Pillow>=8.4.0',
    'scipy>=1.5.4',
    'torch>=1.10.1',
    'torchvision>=0.11.2',
    'face_detection@git+https://github.com/elliottzheng/face-detection'
]

[project.urls]
homepage = "https://github.com/Ahmednull/L2CS-Net"
repository = "https://github.com/Ahmednull/L2CS-Net"

[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

# https://setuptools.pypa.io/en/stable/userguide/datafiles.html
[tool.setuptools]
include-package-data = true

[tool.setuptools.packages.find]
where = ["."]
models/L2CS-Net/test.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, argparse
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.autograd import Variable
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
import torch.backends.cudnn as cudnn
|
| 10 |
+
import torchvision
|
| 11 |
+
|
| 12 |
+
from l2cs import select_device, natural_keys, gazeto3d, angular, getArch, L2CS, Gaze360, Mpiigaze
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def parse_args():
|
| 16 |
+
"""Parse input arguments."""
|
| 17 |
+
parser = argparse.ArgumentParser(
|
| 18 |
+
description='Gaze estimation using L2CSNet .')
|
| 19 |
+
# Gaze360
|
| 20 |
+
parser.add_argument(
|
| 21 |
+
'--gaze360image_dir', dest='gaze360image_dir', help='Directory path for gaze images.',
|
| 22 |
+
default='datasets/Gaze360/Image', type=str)
|
| 23 |
+
parser.add_argument(
|
| 24 |
+
'--gaze360label_dir', dest='gaze360label_dir', help='Directory path for gaze labels.',
|
| 25 |
+
default='datasets/Gaze360/Label/test.label', type=str)
|
| 26 |
+
# mpiigaze
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
'--gazeMpiimage_dir', dest='gazeMpiimage_dir', help='Directory path for gaze images.',
|
| 29 |
+
default='datasets/MPIIFaceGaze/Image', type=str)
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
'--gazeMpiilabel_dir', dest='gazeMpiilabel_dir', help='Directory path for gaze labels.',
|
| 32 |
+
default='datasets/MPIIFaceGaze/Label', type=str)
|
| 33 |
+
# Important args -------------------------------------------------------------------------------------------------------
|
| 34 |
+
# ----------------------------------------------------------------------------------------------------------------------
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
'--dataset', dest='dataset', help='gaze360, mpiigaze',
|
| 37 |
+
default= "gaze360", type=str)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
'--snapshot', dest='snapshot', help='Path to the folder contains models.',
|
| 40 |
+
default='output/snapshots/L2CS-gaze360-_loader-180-4-lr', type=str)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
'--evalpath', dest='evalpath', help='path for the output evaluating gaze test.',
|
| 43 |
+
default="evaluation/L2CS-gaze360-_loader-180-4-lr", type=str)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
'--gpu',dest='gpu_id', help='GPU device id to use [0]',
|
| 46 |
+
default="0", type=str)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
'--batch_size', dest='batch_size', help='Batch size.',
|
| 49 |
+
default=100, type=int)
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
'--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, [ResNet50], ''ResNet101, ResNet152, Squeezenet_1_0, Squeezenet_1_1, MobileNetV2',
|
| 52 |
+
default='ResNet50', type=str)
|
| 53 |
+
# ---------------------------------------------------------------------------------------------------------------------
|
| 54 |
+
# Important args ------------------------------------------------------------------------------------------------------
|
| 55 |
+
args = parser.parse_args()
|
| 56 |
+
return args
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def getArch(arch,bins):
|
| 60 |
+
# Base network structure
|
| 61 |
+
if arch == 'ResNet18':
|
| 62 |
+
model = L2CS( torchvision.models.resnet.BasicBlock,[2, 2, 2, 2], bins)
|
| 63 |
+
elif arch == 'ResNet34':
|
| 64 |
+
model = L2CS( torchvision.models.resnet.BasicBlock,[3, 4, 6, 3], bins)
|
| 65 |
+
elif arch == 'ResNet101':
|
| 66 |
+
model = L2CS( torchvision.models.resnet.Bottleneck,[3, 4, 23, 3], bins)
|
| 67 |
+
elif arch == 'ResNet152':
|
| 68 |
+
model = L2CS( torchvision.models.resnet.Bottleneck,[3, 8, 36, 3], bins)
|
| 69 |
+
else:
|
| 70 |
+
if arch != 'ResNet50':
|
| 71 |
+
print('Invalid value for architecture is passed! '
|
| 72 |
+
'The default value of ResNet50 will be used instead!')
|
| 73 |
+
model = L2CS( torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
|
| 74 |
+
return model
|
| 75 |
+
|
| 76 |
+
if __name__ == '__main__':
|
| 77 |
+
args = parse_args()
|
| 78 |
+
cudnn.enabled = True
|
| 79 |
+
gpu = select_device(args.gpu_id, batch_size=args.batch_size)
|
| 80 |
+
batch_size=args.batch_size
|
| 81 |
+
arch=args.arch
|
| 82 |
+
data_set=args.dataset
|
| 83 |
+
evalpath =args.evalpath
|
| 84 |
+
snapshot_path = args.snapshot
|
| 85 |
+
bins=args.bins
|
| 86 |
+
angle=args.angle
|
| 87 |
+
bin_width=args.bin_width
|
| 88 |
+
|
| 89 |
+
transformations = transforms.Compose([
|
| 90 |
+
transforms.Resize(448),
|
| 91 |
+
transforms.ToTensor(),
|
| 92 |
+
transforms.Normalize(
|
| 93 |
+
mean=[0.485, 0.456, 0.406],
|
| 94 |
+
std=[0.229, 0.224, 0.225]
|
| 95 |
+
)
|
| 96 |
+
])
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if data_set=="gaze360":
|
| 101 |
+
|
| 102 |
+
gaze_dataset=Gaze360(args.gaze360label_dir,args.gaze360image_dir, transformations, 180, 4, train=False)
|
| 103 |
+
test_loader = torch.utils.data.DataLoader(
|
| 104 |
+
dataset=gaze_dataset,
|
| 105 |
+
batch_size=batch_size,
|
| 106 |
+
shuffle=False,
|
| 107 |
+
num_workers=4,
|
| 108 |
+
pin_memory=True)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if not os.path.exists(evalpath):
|
| 113 |
+
os.makedirs(evalpath)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# list all epochs for testing
|
| 117 |
+
folder = os.listdir(snapshot_path)
|
| 118 |
+
folder.sort(key=natural_keys)
|
| 119 |
+
softmax = nn.Softmax(dim=1)
|
| 120 |
+
with open(os.path.join(evalpath,data_set+".log"), 'w') as outfile:
|
| 121 |
+
configuration = f"\ntest configuration = gpu_id={gpu}, batch_size={batch_size}, model_arch={arch}\nStart testing dataset={data_set}----------------------------------------\n"
|
| 122 |
+
print(configuration)
|
| 123 |
+
outfile.write(configuration)
|
| 124 |
+
epoch_list=[]
|
| 125 |
+
avg_yaw=[]
|
| 126 |
+
avg_pitch=[]
|
| 127 |
+
avg_MAE=[]
|
| 128 |
+
for epochs in folder:
|
| 129 |
+
# Base network structure
|
| 130 |
+
model=getArch(arch, 90)
|
| 131 |
+
saved_state_dict = torch.load(os.path.join(snapshot_path, epochs))
|
| 132 |
+
model.load_state_dict(saved_state_dict)
|
| 133 |
+
model.cuda(gpu)
|
| 134 |
+
model.eval()
|
| 135 |
+
total = 0
|
| 136 |
+
idx_tensor = [idx for idx in range(90)]
|
| 137 |
+
idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
|
| 138 |
+
avg_error = .0
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
with torch.no_grad():
|
| 142 |
+
for j, (images, labels, cont_labels, name) in enumerate(test_loader):
|
| 143 |
+
images = Variable(images).cuda(gpu)
|
| 144 |
+
total += cont_labels.size(0)
|
| 145 |
+
|
| 146 |
+
label_pitch = cont_labels[:,0].float()*np.pi/180
|
| 147 |
+
label_yaw = cont_labels[:,1].float()*np.pi/180
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
gaze_pitch, gaze_yaw = model(images)
|
| 151 |
+
|
| 152 |
+
# Binned predictions
|
| 153 |
+
_, pitch_bpred = torch.max(gaze_pitch.data, 1)
|
| 154 |
+
_, yaw_bpred = torch.max(gaze_yaw.data, 1)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Continuous predictions
|
| 158 |
+
pitch_predicted = softmax(gaze_pitch)
|
| 159 |
+
yaw_predicted = softmax(gaze_yaw)
|
| 160 |
+
|
| 161 |
+
# mapping from binned (0 to 28) to angels (-180 to 180)
|
| 162 |
+
pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 4 - 180
|
| 163 |
+
yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 4 - 180
|
| 164 |
+
|
| 165 |
+
pitch_predicted = pitch_predicted*np.pi/180
|
| 166 |
+
yaw_predicted = yaw_predicted*np.pi/180
|
| 167 |
+
|
| 168 |
+
for p,y,pl,yl in zip(pitch_predicted,yaw_predicted,label_pitch,label_yaw):
|
| 169 |
+
avg_error += angular(gazeto3d([p,y]), gazeto3d([pl,yl]))
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
x = ''.join(filter(lambda i: i.isdigit(), epochs))
|
| 174 |
+
epoch_list.append(x)
|
| 175 |
+
avg_MAE.append(avg_error/total)
|
| 176 |
+
loger = f"[{epochs}---{args.dataset}] Total Num:{total},MAE:{avg_error/total}\n"
|
| 177 |
+
outfile.write(loger)
|
| 178 |
+
print(loger)
|
| 179 |
+
|
| 180 |
+
fig = plt.figure(figsize=(14, 8))
|
| 181 |
+
plt.xlabel('epoch')
|
| 182 |
+
plt.ylabel('avg')
|
| 183 |
+
plt.title('Gaze angular error')
|
| 184 |
+
plt.legend()
|
| 185 |
+
plt.plot(epoch_list, avg_MAE, color='k', label='mae')
|
| 186 |
+
fig.savefig(os.path.join(evalpath,data_set+".png"), format='png')
|
| 187 |
+
plt.show()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
elif data_set=="mpiigaze":
|
| 192 |
+
model_used=getArch(arch, bins)
|
| 193 |
+
|
| 194 |
+
for fold in range(15):
|
| 195 |
+
folder = os.listdir(args.gazeMpiilabel_dir)
|
| 196 |
+
folder.sort()
|
| 197 |
+
testlabelpathombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]
|
| 198 |
+
gaze_dataset=Mpiigaze(testlabelpathombined,args.gazeMpiimage_dir, transformations, False, angle, fold)
|
| 199 |
+
|
| 200 |
+
test_loader = torch.utils.data.DataLoader(
|
| 201 |
+
dataset=gaze_dataset,
|
| 202 |
+
batch_size=batch_size,
|
| 203 |
+
shuffle=True,
|
| 204 |
+
num_workers=4,
|
| 205 |
+
pin_memory=True)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
if not os.path.exists(os.path.join(evalpath, f"fold"+str(fold))):
|
| 209 |
+
os.makedirs(os.path.join(evalpath, f"fold"+str(fold)))
|
| 210 |
+
|
| 211 |
+
# list all epochs for testing
|
| 212 |
+
folder = os.listdir(os.path.join(snapshot_path,"fold"+str(fold)))
|
| 213 |
+
folder.sort(key=natural_keys)
|
| 214 |
+
|
| 215 |
+
softmax = nn.Softmax(dim=1)
|
| 216 |
+
with open(os.path.join(evalpath, os.path.join("fold"+str(fold), data_set+".log")), 'w') as outfile:
|
| 217 |
+
configuration = f"\ntest configuration equal gpu_id={gpu}, batch_size={batch_size}, model_arch={arch}\nStart testing dataset={data_set}, fold={fold}---------------------------------------\n"
|
| 218 |
+
print(configuration)
|
| 219 |
+
outfile.write(configuration)
|
| 220 |
+
epoch_list=[]
|
| 221 |
+
avg_MAE=[]
|
| 222 |
+
for epochs in folder:
|
| 223 |
+
model=model_used
|
| 224 |
+
saved_state_dict = torch.load(os.path.join(snapshot_path+"/fold"+str(fold),epochs))
|
| 225 |
+
model= nn.DataParallel(model,device_ids=[0])
|
| 226 |
+
model.load_state_dict(saved_state_dict)
|
| 227 |
+
model.cuda(gpu)
|
| 228 |
+
model.eval()
|
| 229 |
+
total = 0
|
| 230 |
+
idx_tensor = [idx for idx in range(28)]
|
| 231 |
+
idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
|
| 232 |
+
        avg_error = 0.0
        with torch.no_grad():
            for j, (images, labels, cont_labels, name) in enumerate(test_loader):
                images = Variable(images).cuda(gpu)
                total += cont_labels.size(0)

                label_pitch = cont_labels[:, 0].float() * np.pi / 180
                label_yaw = cont_labels[:, 1].float() * np.pi / 180

                gaze_pitch, gaze_yaw = model(images)

                # Binned predictions
                _, pitch_bpred = torch.max(gaze_pitch.data, 1)
                _, yaw_bpred = torch.max(gaze_yaw.data, 1)

                # Continuous predictions
                pitch_predicted = softmax(gaze_pitch)
                yaw_predicted = softmax(gaze_yaw)

                # Mapping from binned (0 to 28) to angles (-42 to 42)
                pitch_predicted = \
                    torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 3 - 42
                yaw_predicted = \
                    torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 3 - 42

                pitch_predicted = pitch_predicted * np.pi / 180
                yaw_predicted = yaw_predicted * np.pi / 180

                for p, y, pl, yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw):
                    avg_error += angular(gazeto3d([p, y]), gazeto3d([pl, yl]))

        x = ''.join(filter(lambda i: i.isdigit(), epochs))
        epoch_list.append(x)
        avg_MAE.append(avg_error / total)
        loger = f"[{epochs}---{args.dataset}] Total Num:{total},MAE:{avg_error/total} \n"
        outfile.write(loger)
        print(loger)

    fig = plt.figure(figsize=(14, 8))
    plt.xlabel('epoch')
    plt.ylabel('avg')
    plt.title('Gaze angular error')
    # Plot before legend() so the 'mae' label is registered.
    plt.plot(epoch_list, avg_MAE, color='k', label='mae')
    plt.legend()
    fig.savefig(os.path.join(evalpath, os.path.join("fold" + str(fold), data_set + ".png")), format='png')
    # plt.show()
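The evaluation above decodes continuous angles with a soft-argmax: a softmax expectation over 28 bins of 3° each, offset by −42°. A minimal standalone sketch of that decode in NumPy (the function name and the peaked-logits example are illustrative, not part of the repo):

```python
import numpy as np

def decode_bins_to_degrees(logits, n_bins=28, bin_width=3.0, offset=-42.0):
    """Soft-argmax decode: softmax expectation over bin indices, mapped to degrees.

    Mirrors the MPIIGaze evaluation path, where 28 bins of 3 degrees cover
    [-42, 42): angle = sum(softmax(logits) * idx) * 3 - 42.
    """
    exp = np.exp(logits - logits.max())   # numerically stable softmax
    probs = exp / exp.sum()
    idx = np.arange(n_bins)
    return float(np.sum(probs * idx) * bin_width + offset)

# A logit vector sharply peaked at bin 14 decodes near 14 * 3 - 42 = 0 degrees.
peaked = np.full(28, -10.0)
peaked[14] = 10.0
angle = decode_bins_to_degrees(peaked)
```

Uniform logits decode to the grid centre (13.5 · 3 − 42 = −1.5°), which is why a well-calibrated model must actually concentrate probability mass rather than hedge across bins.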
models/L2CS-Net/train.py
ADDED
@@ -0,0 +1,384 @@
import os
import argparse
import time

import torch.utils.model_zoo as model_zoo
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision

from l2cs import L2CS, select_device, Gaze360, Mpiigaze


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Gaze estimation using L2CSNet.')
    # Gaze360
    parser.add_argument(
        '--gaze360image_dir', dest='gaze360image_dir', help='Directory path for gaze images.',
        default='datasets/Gaze360/Image', type=str)
    parser.add_argument(
        '--gaze360label_dir', dest='gaze360label_dir', help='Directory path for gaze labels.',
        default='datasets/Gaze360/Label/train.label', type=str)
    # MPIIGaze
    parser.add_argument(
        '--gazeMpiimage_dir', dest='gazeMpiimage_dir', help='Directory path for gaze images.',
        default='datasets/MPIIFaceGaze/Image', type=str)
    parser.add_argument(
        '--gazeMpiilabel_dir', dest='gazeMpiilabel_dir', help='Directory path for gaze labels.',
        default='datasets/MPIIFaceGaze/Label', type=str)

    # Important args ---------------------------------------------------------------------
    parser.add_argument(
        '--dataset', dest='dataset', help='mpiigaze, rtgene, gaze360, ethgaze',
        default="gaze360", type=str)
    parser.add_argument(
        '--output', dest='output', help='Path of output models.',
        default='output/snapshots/', type=str)
    parser.add_argument(
        '--snapshot', dest='snapshot', help='Path of model snapshot.',
        default='', type=str)
    parser.add_argument(
        '--gpu', dest='gpu_id', help='GPU device id to use [0] or multiple 0,1,2,3',
        default='0', type=str)
    parser.add_argument(
        '--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
        default=60, type=int)
    parser.add_argument(
        '--batch_size', dest='batch_size', help='Batch size.',
        default=1, type=int)
    parser.add_argument(
        '--arch', dest='arch',
        help='Network architecture, can be: ResNet18, ResNet34, [ResNet50], '
             'ResNet101, ResNet152, Squeezenet_1_0, Squeezenet_1_1, MobileNetV2',
        default='ResNet50', type=str)
    parser.add_argument(
        '--alpha', dest='alpha', help='Regression loss coefficient.',
        default=1, type=float)
    parser.add_argument(
        '--lr', dest='lr', help='Base learning rate.',
        default=0.00001, type=float)
    # ------------------------------------------------------------------------------------
    args = parser.parse_args()
    return args


def get_ignored_params(model):
    # Generator function that yields ignored params.
    b = [model.conv1, model.bn1, model.fc_finetune]
    for i in range(len(b)):
        for module_name, module in b[i].named_modules():
            if 'bn' in module_name:
                module.eval()
            for name, param in module.named_parameters():
                yield param


def get_non_ignored_params(model):
    # Generator function that yields params that will be optimized.
    b = [model.layer1, model.layer2, model.layer3, model.layer4]
    for i in range(len(b)):
        for module_name, module in b[i].named_modules():
            if 'bn' in module_name:
                module.eval()
            for name, param in module.named_parameters():
                yield param


def get_fc_params(model):
    # Generator function that yields fc layer params.
    b = [model.fc_yaw_gaze, model.fc_pitch_gaze]
    for i in range(len(b)):
        for module_name, module in b[i].named_modules():
            for name, param in module.named_parameters():
                yield param


def load_filtered_state_dict(model, snapshot):
    # By user apaszke from discuss.pytorch.org
    model_dict = model.state_dict()
    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
    model_dict.update(snapshot)
    model.load_state_dict(model_dict)


def getArch_weights(arch, bins):
    if arch == 'ResNet18':
        model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
        pre_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    elif arch == 'ResNet34':
        model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
        pre_url = 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'
    elif arch == 'ResNet101':
        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
        pre_url = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
    elif arch == 'ResNet152':
        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], bins)
        pre_url = 'https://download.pytorch.org/models/resnet152-b121ed2d.pth'
    else:
        if arch != 'ResNet50':
            print('Invalid value for architecture is passed! '
                  'The default value of ResNet50 will be used instead!')
        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
        pre_url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'

    return model, pre_url


if __name__ == '__main__':
    args = parse_args()
    cudnn.enabled = True
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    gpu = select_device(args.gpu_id, batch_size=args.batch_size)
    data_set = args.dataset
    alpha = args.alpha
    output = args.output

    transformations = transforms.Compose([
        transforms.Resize(448),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    if data_set == "gaze360":
        model, pre_url = getArch_weights(args.arch, 90)
        if args.snapshot == '':
            load_filtered_state_dict(model, model_zoo.load_url(pre_url))
        else:
            saved_state_dict = torch.load(args.snapshot)
            model.load_state_dict(saved_state_dict)

        model.cuda(gpu)
        dataset = Gaze360(args.gaze360label_dir, args.gaze360image_dir, transformations, 180, 4)
        print('Loading data.')
        train_loader_gaze = DataLoader(
            dataset=dataset,
            batch_size=int(batch_size),
            shuffle=True,
            num_workers=0,
            pin_memory=True)
        torch.backends.cudnn.benchmark = True

        summary_name = '{}_{}'.format('L2CS-gaze360-', int(time.time()))
        output = os.path.join(output, summary_name)
        if not os.path.exists(output):
            os.makedirs(output)

        criterion = nn.CrossEntropyLoss().cuda(gpu)
        reg_criterion = nn.MSELoss().cuda(gpu)
        softmax = nn.Softmax(dim=1).cuda(gpu)
        idx_tensor = [idx for idx in range(90)]
        idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)

        # Optimizer gaze
        optimizer_gaze = torch.optim.Adam([
            {'params': get_ignored_params(model), 'lr': 0},
            {'params': get_non_ignored_params(model), 'lr': args.lr},
            {'params': get_fc_params(model), 'lr': args.lr}
        ], args.lr)

        configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\nStart training dataset={data_set}, loader={len(train_loader_gaze)}------------------------- \n"
        print(configuration)
        for epoch in range(num_epochs):
            sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0

            for i, (images_gaze, labels_gaze, cont_labels_gaze, name) in enumerate(train_loader_gaze):
                images_gaze = Variable(images_gaze).cuda(gpu)

                # Binned labels
                label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
                label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)

                # Continuous labels
                label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
                label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)

                pitch, yaw = model(images_gaze)

                # Cross entropy loss
                loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
                loss_yaw_gaze = criterion(yaw, label_yaw_gaze)

                # MSE loss
                pitch_predicted = softmax(pitch)
                yaw_predicted = softmax(yaw)

                pitch_predicted = \
                    torch.sum(pitch_predicted * idx_tensor, 1) * 4 - 180
                yaw_predicted = \
                    torch.sum(yaw_predicted * idx_tensor, 1) * 4 - 180

                loss_reg_pitch = reg_criterion(
                    pitch_predicted, label_pitch_cont_gaze)
                loss_reg_yaw = reg_criterion(
                    yaw_predicted, label_yaw_cont_gaze)

                # Total loss
                loss_pitch_gaze += alpha * loss_reg_pitch
                loss_yaw_gaze += alpha * loss_reg_yaw

                sum_loss_pitch_gaze += loss_pitch_gaze
                sum_loss_yaw_gaze += loss_yaw_gaze

                loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
                grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
                optimizer_gaze.zero_grad(set_to_none=True)
                torch.autograd.backward(loss_seq, grad_seq)
                optimizer_gaze.step()
                # scheduler.step()

                iter_gaze += 1

                if (i + 1) % 100 == 0:
                    print('Epoch [%d/%d], Iter [%d/%d] Losses: '
                          'Gaze Yaw %.4f, Gaze Pitch %.4f' % (
                              epoch + 1,
                              num_epochs,
                              i + 1,
                              len(dataset) // batch_size,
                              sum_loss_pitch_gaze / iter_gaze,
                              sum_loss_yaw_gaze / iter_gaze
                          ))

            # Save a snapshot after every epoch.
            if epoch % 1 == 0 and epoch < num_epochs:
                print('Taking snapshot...',
                      torch.save(model.state_dict(),
                                 output + '/' +
                                 '_epoch_' + str(epoch + 1) + '.pkl'))

    elif data_set == "mpiigaze":
        folder = os.listdir(args.gazeMpiilabel_dir)
        folder.sort()
        testlabelpathcombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]
        for fold in range(15):
            model, pre_url = getArch_weights(args.arch, 28)
            load_filtered_state_dict(model, model_zoo.load_url(pre_url))
            model = nn.DataParallel(model)
            model.to(gpu)
            print('Loading data.')
            dataset = Mpiigaze(testlabelpathcombined, args.gazeMpiimage_dir, transformations, True, fold)
            train_loader_gaze = DataLoader(
                dataset=dataset,
                batch_size=int(batch_size),
                shuffle=True,
                num_workers=4,
                pin_memory=True)
            torch.backends.cudnn.benchmark = True

            summary_name = '{}_{}'.format('L2CS-mpiigaze', int(time.time()))

            if not os.path.exists(os.path.join(output + '/{}'.format(summary_name), 'fold' + str(fold))):
                os.makedirs(os.path.join(output + '/{}'.format(summary_name), 'fold' + str(fold)))

            criterion = nn.CrossEntropyLoss().cuda(gpu)
            reg_criterion = nn.MSELoss().cuda(gpu)
            softmax = nn.Softmax(dim=1).cuda(gpu)
            idx_tensor = [idx for idx in range(28)]
            idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)

            # Optimizer gaze. The model is wrapped in DataParallel here, so the
            # single-argument parameter-group helpers must look at model.module.
            optimizer_gaze = torch.optim.Adam([
                {'params': get_ignored_params(model.module), 'lr': 0},
                {'params': get_non_ignored_params(model.module), 'lr': args.lr},
                {'params': get_fc_params(model.module), 'lr': args.lr}
            ], args.lr)

            configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\n Start training dataset={data_set}, loader={len(train_loader_gaze)}, fold={fold}--------------\n"
            print(configuration)
            for epoch in range(num_epochs):
                sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0

                for i, (images_gaze, labels_gaze, cont_labels_gaze, name) in enumerate(train_loader_gaze):
                    images_gaze = Variable(images_gaze).cuda(gpu)

                    # Binned labels
                    label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
                    label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)

                    # Continuous labels
                    label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
                    label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)

                    pitch, yaw = model(images_gaze)

                    # Cross entropy loss
                    loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
                    loss_yaw_gaze = criterion(yaw, label_yaw_gaze)

                    # MSE loss
                    pitch_predicted = softmax(pitch)
                    yaw_predicted = softmax(yaw)

                    pitch_predicted = \
                        torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 42
                    yaw_predicted = \
                        torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 42

                    loss_reg_pitch = reg_criterion(
                        pitch_predicted, label_pitch_cont_gaze)
                    loss_reg_yaw = reg_criterion(
                        yaw_predicted, label_yaw_cont_gaze)

                    # Total loss
                    loss_pitch_gaze += alpha * loss_reg_pitch
                    loss_yaw_gaze += alpha * loss_reg_yaw

                    sum_loss_pitch_gaze += loss_pitch_gaze
                    sum_loss_yaw_gaze += loss_yaw_gaze

                    loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
                    grad_seq = \
                        [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]

                    optimizer_gaze.zero_grad(set_to_none=True)
                    torch.autograd.backward(loss_seq, grad_seq)
                    optimizer_gaze.step()

                    iter_gaze += 1

                    if (i + 1) % 100 == 0:
                        print('Epoch [%d/%d], Iter [%d/%d] Losses: '
                              'Gaze Yaw %.4f, Gaze Pitch %.4f' % (
                                  epoch + 1,
                                  num_epochs,
                                  i + 1,
                                  len(dataset) // batch_size,
                                  sum_loss_pitch_gaze / iter_gaze,
                                  sum_loss_yaw_gaze / iter_gaze
                              ))

                # Save models at numbered epochs.
                if epoch % 1 == 0 and epoch < num_epochs:
                    print('Taking snapshot...',
                          torch.save(model.state_dict(),
                                     output + '/fold' + str(fold) + '/' +
                                     '_epoch_' + str(epoch + 1) + '.pkl'))
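The training objective above pairs a cross-entropy term on the binned label with an alpha-weighted MSE term on the soft-argmax decode. A minimal NumPy sketch of that combined loss for a single sample (the function name is illustrative; the bin layout matches the Gaze360 branch: 90 bins of 4°, offset −180°):

```python
import numpy as np

def l2cs_single_loss(logits, bin_label, cont_label_deg, alpha=1.0,
                     n_bins=90, bin_width=4.0, offset=-180.0):
    """Cross-entropy on the binned label plus alpha * MSE on the
    expectation-decoded continuous angle, mirroring the training loop above."""
    z = logits - logits.max()
    probs = np.exp(z) / np.exp(z).sum()
    ce = -np.log(probs[bin_label])                      # classification term
    decoded = np.sum(probs * np.arange(n_bins)) * bin_width + offset
    mse = (decoded - cont_label_deg) ** 2               # regression term
    return ce + alpha * mse, decoded

# Logits sharply peaked at bin 50 decode near 50 * 4 - 180 = 20 degrees,
# so both terms are small when the peak matches the label.
logits = np.zeros(90)
logits[50] = 8.0
loss, decoded = l2cs_single_loss(logits, bin_label=50,
                                 cont_label_deg=50 * 4.0 - 180.0)
```

The regression term is what pulls probability mass toward adjacent bins instead of letting the network treat bins as unordered classes, which is the core idea of L2CS-Net's two-loss design.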
models/gaze_calibration.py
ADDED
@@ -0,0 +1,146 @@
# 9-point gaze calibration for L2CS-Net.
# Maps raw gaze angles -> normalised screen coords via polynomial least-squares.
# Centre point is the bias reference (subtracted from all readings).

import numpy as np
from dataclasses import dataclass, field

# 3x3 grid, centre first (bias ref), then row by row
DEFAULT_TARGETS = [
    (0.5, 0.5),
    (0.15, 0.15), (0.50, 0.15), (0.85, 0.15),
    (0.15, 0.50), (0.85, 0.50),
    (0.15, 0.85), (0.50, 0.85), (0.85, 0.85),
]


@dataclass
class _PointSamples:
    target_x: float
    target_y: float
    yaws: list = field(default_factory=list)
    pitches: list = field(default_factory=list)


def _iqr_filter(values):
    if len(values) < 4:
        return values
    arr = np.array(values)
    q1, q3 = np.percentile(arr, [25, 75])
    iqr = q3 - q1
    lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
    return arr[(arr >= lo) & (arr <= hi)].tolist()


class GazeCalibration:

    def __init__(self, targets=None):
        self._targets = targets or list(DEFAULT_TARGETS)
        self._points = [_PointSamples(tx, ty) for tx, ty in self._targets]
        self._current_idx = 0
        self._fitted = False
        self._W = None  # (6, 2) polynomial weights
        self._yaw_bias = 0.0
        self._pitch_bias = 0.0

    @property
    def num_points(self):
        return len(self._targets)

    @property
    def current_index(self):
        return self._current_idx

    @property
    def current_target(self):
        if self._current_idx < len(self._targets):
            return self._targets[self._current_idx]
        return self._targets[-1]

    @property
    def is_complete(self):
        return self._current_idx >= len(self._targets)

    @property
    def is_fitted(self):
        return self._fitted

    def collect_sample(self, yaw_rad, pitch_rad):
        if self._current_idx >= len(self._points):
            return
        pt = self._points[self._current_idx]
        pt.yaws.append(float(yaw_rad))
        pt.pitches.append(float(pitch_rad))

    def advance(self):
        self._current_idx += 1
        return self._current_idx < len(self._targets)

    @staticmethod
    def _poly_features(yaw, pitch):
        # [yaw^2, pitch^2, yaw*pitch, yaw, pitch, 1]
        return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0],
                        dtype=np.float64)

    def fit(self):
        # Bias from centre point (index 0)
        center = self._points[0]
        center_yaws = _iqr_filter(center.yaws)
        center_pitches = _iqr_filter(center.pitches)
        if len(center_yaws) < 2 or len(center_pitches) < 2:
            return False
        self._yaw_bias = float(np.median(center_yaws))
        self._pitch_bias = float(np.median(center_pitches))

        rows_A, rows_B = [], []
        for pt in self._points:
            clean_yaws = _iqr_filter(pt.yaws)
            clean_pitches = _iqr_filter(pt.pitches)
            if len(clean_yaws) < 2 or len(clean_pitches) < 2:
                continue
            med_yaw = float(np.median(clean_yaws)) - self._yaw_bias
            med_pitch = float(np.median(clean_pitches)) - self._pitch_bias
            rows_A.append(self._poly_features(med_yaw, med_pitch))
            rows_B.append([pt.target_x, pt.target_y])

        if len(rows_A) < 5:
            return False

        A = np.array(rows_A, dtype=np.float64)
        B = np.array(rows_B, dtype=np.float64)
        try:
            W, _, _, _ = np.linalg.lstsq(A, B, rcond=None)
            self._W = W
            self._fitted = True
            return True
        except np.linalg.LinAlgError:
            return False

    def predict(self, yaw_rad, pitch_rad):
        if not self._fitted or self._W is None:
            return 0.5, 0.5
        feat = self._poly_features(yaw_rad - self._yaw_bias, pitch_rad - self._pitch_bias)
        xy = feat @ self._W
        return float(np.clip(xy[0], 0, 1)), float(np.clip(xy[1], 0, 1))

    def to_dict(self):
        return {
            "targets": self._targets,
            "fitted": self._fitted,
            "current_index": self._current_idx,
            "W": self._W.tolist() if self._W is not None else None,
            "yaw_bias": self._yaw_bias,
            "pitch_bias": self._pitch_bias,
        }

    @classmethod
    def from_dict(cls, d):
        cal = cls(targets=d.get("targets", DEFAULT_TARGETS))
        cal._fitted = d.get("fitted", False)
        cal._current_idx = d.get("current_index", 0)
        cal._yaw_bias = d.get("yaw_bias", 0.0)
        cal._pitch_bias = d.get("pitch_bias", 0.0)
        w = d.get("W")
        if w is not None:
            cal._W = np.array(w, dtype=np.float64)
        return cal
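The fit/predict pair above is ordinary least squares over six polynomial features. A compact standalone sketch of the same mapping, on synthetic angles where screen x tracks yaw and y tracks pitch linearly (the helper name and the synthetic data are illustrative, not the module's API):

```python
import numpy as np

def poly_feats(yaw, pitch):
    # Same feature vector as _poly_features above: [yaw^2, pitch^2, yaw*pitch, yaw, pitch, 1]
    return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0])

# Synthetic "calibration": 3x3 grid of (yaw, pitch) readings in radians,
# with ground-truth screen targets that are exactly linear in the angles.
angles = [(-0.3, -0.2), (0.0, -0.2), (0.3, -0.2),
          (-0.3, 0.0), (0.0, 0.0), (0.3, 0.0),
          (-0.3, 0.2), (0.0, 0.2), (0.3, 0.2)]
targets = [((yaw + 0.3) / 0.6, (pitch + 0.2) / 0.4) for yaw, pitch in angles]

A = np.array([poly_feats(y, p) for y, p in angles])   # (9, 6) design matrix
B = np.array(targets)                                  # (9, 2) screen coords
W, *_ = np.linalg.lstsq(A, B, rcond=None)              # (6, 2) weights, as in fit()

# Predict: centred gaze should map to the middle of the screen.
x, y = poly_feats(0.0, 0.0) @ W
```

With 9 points and 6 features the system is overdetermined, which is why the class can still fit when up to 4 of the 9 points are rejected by the IQR filter (the `len(rows_A) < 5` guard).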
models/gaze_eye_fusion.py
ADDED
@@ -0,0 +1,66 @@
# Fuses calibrated gaze position with eye openness (EAR) for focus detection.
# Takes L2CS gaze angles + MediaPipe landmarks, outputs screen coords + focus decision.

import math
import numpy as np

from .gaze_calibration import GazeCalibration
from .eye_scorer import compute_avg_ear

_EAR_BLINK = 0.18
_ON_SCREEN_MARGIN = 0.08


class GazeEyeFusion:

    def __init__(self, calibration, ear_weight=0.3, gaze_weight=0.7, focus_threshold=0.52):
        if not calibration.is_fitted:
            raise ValueError("Calibration must be fitted first")
        self._cal = calibration
        self._ear_w = ear_weight
        self._gaze_w = gaze_weight
        self._threshold = focus_threshold
        self._smooth_x = 0.5
        self._smooth_y = 0.5
        self._alpha = 0.5

    def update(self, yaw_rad, pitch_rad, landmarks):
        gx, gy = self._cal.predict(yaw_rad, pitch_rad)

        # EMA-smooth the gaze position
        self._smooth_x += self._alpha * (gx - self._smooth_x)
        self._smooth_y += self._alpha * (gy - self._smooth_y)
        gx, gy = self._smooth_x, self._smooth_y

        on_screen = (
            -_ON_SCREEN_MARGIN <= gx <= 1.0 + _ON_SCREEN_MARGIN and
            -_ON_SCREEN_MARGIN <= gy <= 1.0 + _ON_SCREEN_MARGIN
        )

        ear = None
        ear_score = 1.0
        if landmarks is not None:
            ear = compute_avg_ear(landmarks)
            ear_score = 0.0 if ear < _EAR_BLINK else min(ear / 0.30, 1.0)

        # Penalise gaze near screen edges
        gaze_score = 1.0 if on_screen else 0.0
        if on_screen:
            dx = max(0.0, abs(gx - 0.5) - 0.3)
            dy = max(0.0, abs(gy - 0.5) - 0.3)
            gaze_score = max(0.0, 1.0 - math.sqrt(dx**2 + dy**2) * 5.0)

        score = float(np.clip(self._gaze_w * gaze_score + self._ear_w * ear_score, 0, 1))

        return {
            "gaze_x": round(float(gx), 4),
            "gaze_y": round(float(gy), 4),
            "on_screen": on_screen,
            "ear": round(ear, 4) if ear is not None else None,
            "focus_score": round(score, 4),
            "focused": score >= self._threshold,
        }

    def reset(self):
        self._smooth_x = 0.5
        self._smooth_y = 0.5
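The weighted fusion above reduces to a few lines of arithmetic; a standalone sketch of the score (weights and thresholds mirror the class defaults; the helper name is illustrative):

```python
import math

def fused_focus_score(gx, gy, ear, gaze_w=0.7, ear_w=0.3,
                      ear_blink=0.18, ear_full=0.30):
    """Weighted blend of a centre-distance gaze score and an EAR openness score,
    mirroring GazeEyeFusion.update above."""
    # Gaze score: 1.0 inside the central 60% of the screen, decaying toward edges.
    dx = max(0.0, abs(gx - 0.5) - 0.3)
    dy = max(0.0, abs(gy - 0.5) - 0.3)
    gaze_score = max(0.0, 1.0 - math.sqrt(dx**2 + dy**2) * 5.0)
    # EAR score: 0 when blinking, saturating at ear_full.
    ear_score = 0.0 if ear < ear_blink else min(ear / ear_full, 1.0)
    return min(max(gaze_w * gaze_score + ear_w * ear_score, 0.0), 1.0)

centered_open = fused_focus_score(0.5, 0.5, ear=0.32)   # eyes open, gaze centred
blinking = fused_focus_score(0.5, 0.5, ear=0.10)        # blink zeroes the EAR term
```

Note that a single blink only drops the score to the gaze weight (0.7 here), which stays above the 0.52 focus threshold, so momentary blinks alone do not flip the focus decision.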
requirements.txt
CHANGED
@@ -14,3 +14,7 @@ aiosqlite>=0.19.0
 pydantic>=2.0.0
 xgboost>=2.0.0
 clearml>=2.0.2
+torch>=1.10.1
+torchvision>=0.11.2
+face_detection @ git+https://github.com/elliottzheng/face-detection
+gdown>=5.0.0
src/components/CalibrationOverlay.jsx
ADDED
@@ -0,0 +1,146 @@
import React, { useState, useEffect, useRef, useCallback } from 'react';

const COLLECT_MS = 2000;
const CENTER_MS = 3000; // centre point gets extra time (bias reference)

function CalibrationOverlay({ calibration, videoManager }) {
  const [progress, setProgress] = useState(0);
  const timerRef = useRef(null);
  const startRef = useRef(null);
  const overlayRef = useRef(null);

  const enterFullscreen = useCallback(() => {
    const el = overlayRef.current;
    if (!el) return;
    const req = el.requestFullscreen || el.webkitRequestFullscreen || el.msRequestFullscreen;
    if (req) req.call(el).catch(() => {});
  }, []);

  const exitFullscreen = useCallback(() => {
    if (document.fullscreenElement || document.webkitFullscreenElement) {
      const exit = document.exitFullscreen || document.webkitExitFullscreen || document.msExitFullscreen;
      if (exit) exit.call(document).catch(() => {});
    }
  }, []);

  useEffect(() => {
    if (calibration && calibration.active && !calibration.done) {
      const t = setTimeout(enterFullscreen, 100);
      return () => clearTimeout(t);
    }
  }, [calibration?.active]);

  useEffect(() => {
    if (!calibration || !calibration.active) exitFullscreen();
  }, [calibration?.active]);

  useEffect(() => {
    if (!calibration || !calibration.collecting || calibration.done) {
      setProgress(0);
      if (timerRef.current) cancelAnimationFrame(timerRef.current);
      return;
    }

    startRef.current = performance.now();
    const duration = calibration.index === 0 ? CENTER_MS : COLLECT_MS;

    const tick = () => {
      const pct = Math.min((performance.now() - startRef.current) / duration, 1);
      setProgress(pct);
      if (pct >= 1) {
        if (videoManager) videoManager.nextCalibrationPoint();
        startRef.current = performance.now();
        setProgress(0);
      }
      timerRef.current = requestAnimationFrame(tick);
    };
    timerRef.current = requestAnimationFrame(tick);

    return () => { if (timerRef.current) cancelAnimationFrame(timerRef.current); };
  }, [calibration?.index, calibration?.collecting, calibration?.done]);

  const handleCancel = () => {
    if (videoManager) videoManager.cancelCalibration();
    exitFullscreen();
  };

  if (!calibration || !calibration.active) return null;

  if (calibration.done) {
    return (
      <div ref={overlayRef} style={overlayStyle}>
        <div style={messageBoxStyle}>
          <h2 style={{ margin: '0 0 10px', color: calibration.success ? '#4ade80' : '#f87171' }}>
            {calibration.success ? 'Calibration Complete' : 'Calibration Failed'}
          </h2>
          <p style={{ color: '#ccc', margin: 0 }}>
            {calibration.success
              ? 'Gaze tracking is now active.'
              : 'Not enough samples collected. Try again.'}
          </p>
        </div>
      </div>
    );
  }

  const [tx, ty] = calibration.target || [0.5, 0.5];

  return (
    <div ref={overlayRef} style={overlayStyle}>
      <div style={{
        position: 'absolute', top: '30px', left: '50%', transform: 'translateX(-50%)',
        color: '#fff', fontSize: '16px', textAlign: 'center',
        textShadow: '0 0 8px rgba(0,0,0,0.8)', pointerEvents: 'none',
      }}>
        <div style={{ fontWeight: 'bold', fontSize: '20px' }}>
          Look at the dot ({calibration.index + 1}/{calibration.numPoints})
        </div>
        <div style={{ fontSize: '14px', color: '#aaa', marginTop: '6px' }}>
          {calibration.index === 0
            ? 'Look at the center dot - this sets your baseline'
            : 'Hold your gaze steady on the target'}
        </div>
      </div>

      <div style={{
        position: 'absolute', left: `${tx * 100}%`, top: `${ty * 100}%`,
        transform: 'translate(-50%, -50%)',
      }}>
        <svg width="60" height="60" style={{ position: 'absolute', left: '-30px', top: '-30px' }}>
+
<circle cx="30" cy="30" r="24" fill="none" stroke="rgba(255,255,255,0.15)" strokeWidth="3" />
|
| 111 |
+
<circle cx="30" cy="30" r="24" fill="none" stroke="#4ade80" strokeWidth="3"
|
| 112 |
+
strokeDasharray={`${progress * 150.8} 150.8`} strokeLinecap="round"
|
| 113 |
+
transform="rotate(-90, 30, 30)" />
|
| 114 |
+
</svg>
|
| 115 |
+
<div style={{
|
| 116 |
+
width: '20px', height: '20px', borderRadius: '50%',
|
| 117 |
+
background: 'radial-gradient(circle, #fff 30%, #4ade80 100%)',
|
| 118 |
+
boxShadow: '0 0 20px rgba(74, 222, 128, 0.8)',
|
| 119 |
+
}} />
|
| 120 |
+
</div>
|
| 121 |
+
|
| 122 |
+
<button onClick={handleCancel} style={{
|
| 123 |
+
position: 'absolute', bottom: '40px', left: '50%', transform: 'translateX(-50%)',
|
| 124 |
+
padding: '10px 28px', background: 'rgba(255,255,255,0.1)',
|
| 125 |
+
border: '1px solid rgba(255,255,255,0.3)', color: '#fff',
|
| 126 |
+
borderRadius: '20px', cursor: 'pointer', fontSize: '14px',
|
| 127 |
+
}}>
|
| 128 |
+
Cancel Calibration
|
| 129 |
+
</button>
|
| 130 |
+
</div>
|
| 131 |
+
);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
const overlayStyle = {
|
| 135 |
+
position: 'fixed', top: 0, left: 0, width: '100vw', height: '100vh',
|
| 136 |
+
background: 'rgba(0, 0, 0, 0.92)', zIndex: 10000,
|
| 137 |
+
display: 'flex', alignItems: 'center', justifyContent: 'center',
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
const messageBoxStyle = {
|
| 141 |
+
textAlign: 'center', padding: '30px 40px',
|
| 142 |
+
background: 'rgba(30, 30, 50, 0.9)', borderRadius: '16px',
|
| 143 |
+
border: '1px solid rgba(255,255,255,0.1)',
|
| 144 |
+
};
|
| 145 |
+
|
| 146 |
+
export default CalibrationOverlay;
|
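The progress ring in the overlay animates by setting `strokeDasharray` to a fraction of the circle's circumference; for the `r="24"` circle used here, that circumference is 2π·24 ≈ 150.8, the constant hard-coded in the JSX. A quick sanity check of that constant (the `dasharray` helper name is ours, for illustration only):

```python
import math

# Circumference of the calibration progress ring (r = 24 in the SVG).
RING_RADIUS = 24
CIRCUMFERENCE = 2 * math.pi * RING_RADIUS  # ~150.8, the hard-coded dasharray total

def dasharray(progress):
    """Visible arc length for a progress fraction in [0, 1]."""
    return progress * CIRCUMFERENCE

print(round(CIRCUMFERENCE, 1))   # 150.8
print(round(dasharray(0.5), 1))  # 75.4 (half the ring drawn)
```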
src/components/FocusPageLocal.jsx
CHANGED
@@ -1,4 +1,5 @@
 import React, { useState, useEffect, useRef } from 'react';
+import CalibrationOverlay from './CalibrationOverlay';
 
 function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActive }) {
   const [currentFrame, setCurrentFrame] = useState(15);
@@ -6,6 +7,9 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
   const [stats, setStats] = useState(null);
   const [availableModels, setAvailableModels] = useState([]);
   const [currentModel, setCurrentModel] = useState('mlp');
+  const [calibration, setCalibration] = useState(null);
+  const [l2csBoost, setL2csBoost] = useState(false);
+  const [l2csBoostAvailable, setL2csBoostAvailable] = useState(false);
 
   const localVideoRef = useRef(null);
   const displayCanvasRef = useRef(null);
@@ -23,7 +27,6 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
   useEffect(() => {
     if (!videoManager) return;
 
-    // Set up the callback that updates the timeline
     const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
     videoManager.callbacks.onStatusUpdate = (isFocused) => {
       setTimelineEvents(prev => {
@@ -34,7 +37,10 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
       if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
     };
 
-
+    videoManager.callbacks.onCalibrationUpdate = (cal) => {
+      setCalibration(cal && cal.active ? { ...cal } : null);
+    };
+
     const statsInterval = setInterval(() => {
       if (videoManager && videoManager.getStats) {
         setStats(videoManager.getStats());
@@ -44,6 +50,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
     return () => {
       if (videoManager) {
        videoManager.callbacks.onStatusUpdate = originalOnStatusUpdate;
+       videoManager.callbacks.onCalibrationUpdate = null;
      }
      clearInterval(statsInterval);
    };
@@ -56,6 +63,8 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
      .then(data => {
        if (data.available) setAvailableModels(data.available);
        if (data.current) setCurrentModel(data.current);
+       if (data.l2cs_boost !== undefined) setL2csBoost(data.l2cs_boost);
+       if (data.l2cs_boost_available !== undefined) setL2csBoostAvailable(data.l2cs_boost_available);
      })
      .catch(err => console.error('Failed to fetch models:', err));
  }, []);
@@ -70,12 +79,28 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
      const result = await res.json();
      if (result.updated) {
        setCurrentModel(modelName);
+       setL2csBoostAvailable(modelName !== 'l2cs' && availableModels.includes('l2cs'));
+       if (modelName === 'l2cs') setL2csBoost(false);
      }
    } catch (err) {
      console.error('Failed to switch model:', err);
    }
  };
 
+  const handleBoostToggle = async () => {
+    const next = !l2csBoost;
+    try {
+      const res = await fetch('/api/settings', {
+        method: 'PUT',
+        headers: { 'Content-Type': 'application/json' },
+        body: JSON.stringify({ l2cs_boost: next })
+      });
+      if (res.ok) setL2csBoost(next);
+    } catch (err) {
+      console.error('Failed to toggle L2CS boost:', err);
+    }
+  };
+
  const handleStart = async () => {
    try {
      if (videoManager) {
@@ -443,6 +468,44 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
            {name}
          </button>
        ))}
+       {l2csBoostAvailable && currentModel !== 'l2cs' && (
+         <button
+           onClick={handleBoostToggle}
+           style={{
+             padding: '5px 14px',
+             borderRadius: '16px',
+             border: l2csBoost ? '2px solid #f59e0b' : '1px solid #555',
+             background: l2csBoost ? 'rgba(245, 158, 11, 0.15)' : 'transparent',
+             color: l2csBoost ? '#f59e0b' : '#888',
+             fontSize: '11px',
+             fontWeight: l2csBoost ? 'bold' : 'normal',
+             cursor: 'pointer',
+             transition: 'all 0.2s',
+             marginLeft: '4px',
+           }}
+         >
+           {l2csBoost ? 'GAZE ON' : 'GAZE'}
+         </button>
+       )}
+       {(currentModel === 'l2cs' || l2csBoost) && stats && stats.isStreaming && (
+         <button
+           onClick={() => videoManager && videoManager.startCalibration()}
+           style={{
+             padding: '5px 14px',
+             borderRadius: '16px',
+             border: '1px solid #4ade80',
+             background: 'transparent',
+             color: '#4ade80',
+             fontSize: '12px',
+             fontWeight: 'bold',
+             cursor: 'pointer',
+             transition: 'all 0.2s',
+             marginLeft: '4px',
+           }}
+         >
+           Calibrate
+         </button>
+       )}
      </section>
    )}
 
@@ -513,6 +576,9 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
        onChange={(e) => handleFrameChange(e.target.value)}
      />
    </section>
+
+    {/* Calibration overlay (fixed fullscreen, must be outside overflow:hidden containers) */}
+    <CalibrationOverlay calibration={calibration} videoManager={videoManager} />
  </main>
 );
 }
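The frontend above only shows the client side of the models/settings API; the exact server schema is not part of this diff. A minimal sketch of the payload shapes the JSX appears to assume (field names taken from the code, the `models_response` values and `boost_button_visible` helper are illustrative assumptions):

```python
# Assumed shapes, inferred from the frontend code above (not a server spec):
# GET /api/models returns the model list plus the two L2CS boost flags;
# PUT /api/settings accepts {"l2cs_boost": <bool>}.
models_response = {
    "available": ["mlp", "xgboost", "l2cs"],  # example values
    "current": "mlp",
    "l2cs_boost": False,
    "l2cs_boost_available": True,
}

def boost_button_visible(resp, current_model):
    # Mirrors the JSX guard: the GAZE toggle renders only when l2cs is
    # available as a booster and a different base model is active.
    return resp["l2cs_boost_available"] and current_model != "l2cs"

print(boost_button_visible(models_response, "mlp"))   # True
print(boost_button_visible(models_response, "l2cs"))  # False
```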
src/utils/VideoManagerLocal.js
CHANGED
@@ -39,6 +39,17 @@ export class VideoManagerLocal {
   this.lastNotificationTime = null;
   this.notificationCooldown = 60000;
 
+  // Calibration state
+  this.calibration = {
+    active: false,
+    collecting: false,
+    target: null,
+    index: 0,
+    numPoints: 0,
+    done: false,
+    success: false,
+  };
+
   // Performance stats
   this.stats = {
     framesSent: 0,
@@ -73,8 +84,8 @@
 
   // Create the canvas used for frame capture (smaller for faster encode + transfer)
   this.canvas = document.createElement('canvas');
-  this.canvas.width =
-  this.canvas.height =
+  this.canvas.width = 640;
+  this.canvas.height = 480;
 
   console.log('Local camera initialized');
   return true;
@@ -188,7 +199,7 @@
       this.ws.send(blob);
       this.stats.framesSent++;
     }
-  }, 'image/jpeg', 0.
+  }, 'image/jpeg', 0.75);
 } catch (error) {
   this._sendingBlob = false;
   console.error('Capture error:', error);
@@ -253,6 +264,19 @@
     ctx.textAlign = 'left';
   }
 }
+// Gaze pointer (L2CS + calibration)
+if (data && data.gaze_x !== undefined && data.gaze_y !== undefined) {
+  const gx = data.gaze_x * w;
+  const gy = data.gaze_y * h;
+  ctx.beginPath();
+  ctx.arc(gx, gy, 8, 0, 2 * Math.PI);
+  ctx.fillStyle = data.on_screen ? 'rgba(0, 200, 255, 0.7)' : 'rgba(255, 80, 80, 0.5)';
+  ctx.fill();
+  ctx.strokeStyle = '#FFFFFF';
+  ctx.lineWidth = 2;
+  ctx.stroke();
+}
+
 // Performance stats
 ctx.fillStyle = 'rgba(0,0,0,0.5)';
 ctx.fillRect(0, h - 25, w, 25);
@@ -321,6 +345,9 @@
   mar: data.mar,
   sf: data.sf,
   se: data.se,
+  gaze_x: data.gaze_x,
+  gaze_y: data.gaze_y,
+  on_screen: data.on_screen,
 };
 this.drawDetectionResult(detectionData);
 break;
@@ -338,6 +365,51 @@
   this.sessionStartTime = null;
   break;
 
+case 'calibration_started':
+  this.calibration = {
+    active: true,
+    collecting: true,
+    target: data.target,
+    index: data.index,
+    numPoints: data.num_points,
+    done: false,
+    success: false,
+  };
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+  break;
+
+case 'calibration_point':
+  this.calibration.target = data.target;
+  this.calibration.index = data.index;
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+  break;
+
+case 'calibration_done':
+  this.calibration.collecting = false;
+  this.calibration.done = true;
+  this.calibration.success = data.success;
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+  setTimeout(() => {
+    this.calibration.active = false;
+    if (this.callbacks.onCalibrationUpdate) {
+      this.callbacks.onCalibrationUpdate({ ...this.calibration });
+    }
+  }, 2000);
+  break;
+
+case 'calibration_cancelled':
+  this.calibration = { active: false, collecting: false, target: null, index: 0, numPoints: 0, done: false, success: false };
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+  break;
+
 case 'error':
   console.error('Server error:', data.message);
   break;
@@ -347,6 +419,28 @@
   }
 }
 
+startCalibration() {
+  if (this.ws && this.ws.readyState === WebSocket.OPEN) {
+    this.ws.send(JSON.stringify({ type: 'calibration_start' }));
+  }
+}
+
+nextCalibrationPoint() {
+  if (this.ws && this.ws.readyState === WebSocket.OPEN) {
+    this.ws.send(JSON.stringify({ type: 'calibration_next' }));
+  }
+}
+
+cancelCalibration() {
+  if (this.ws && this.ws.readyState === WebSocket.OPEN) {
+    this.ws.send(JSON.stringify({ type: 'calibration_cancel' }));
+  }
+  this.calibration = { active: false, collecting: false, target: null, index: 0, numPoints: 0, done: false, success: false };
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+}
+
 // Face mesh landmark index groups (matches live_demo.py)
 static FACE_OVAL = [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109,10];
 static LEFT_EYE = [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246];
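The WebSocket cases above form a small client-side state machine: the server drives it with `calibration_started` / `calibration_point` / `calibration_done` / `calibration_cancelled` messages, and the client only sends `calibration_start` / `calibration_next` / `calibration_cancel`. A Python mirror of that state machine, using the same message types and fields (the `handle` function itself is our sketch, not project code):

```python
# Mirror of the client-side calibration state transitions in VideoManagerLocal.
IDLE = {"active": False, "collecting": False, "target": None,
        "index": 0, "numPoints": 0, "done": False, "success": False}

def handle(state, msg):
    t = msg["type"]
    if t == "calibration_started":
        return {"active": True, "collecting": True, "target": msg["target"],
                "index": msg["index"], "numPoints": msg["num_points"],
                "done": False, "success": False}
    if t == "calibration_point":
        return {**state, "target": msg["target"], "index": msg["index"]}
    if t == "calibration_done":
        return {**state, "collecting": False, "done": True,
                "success": msg["success"]}
    if t == "calibration_cancelled":
        return dict(IDLE)
    return state  # unknown message types leave state untouched

s = handle(dict(IDLE), {"type": "calibration_started",
                        "target": [0.5, 0.5], "index": 0, "num_points": 9})
s = handle(s, {"type": "calibration_point", "target": [0.1, 0.1], "index": 1})
s = handle(s, {"type": "calibration_done", "success": True})
print(s["done"], s["success"])  # True True
```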
ui/pipeline.py
CHANGED
@@ -3,6 +3,7 @@ import glob
 import json
 import math
 import os
+import pathlib
 import sys
 
 import numpy as np
@@ -49,10 +50,12 @@ def _clip_features(vec):
 
 
 class _OutputSmoother:
-
+    # Asymmetric EMA: rises fast (recognise focus), falls slower (avoid flicker).
+    # Grace period holds score steady for a few frames when face is lost.
 
-    def __init__(self,
-        self.
+    def __init__(self, alpha_up=0.55, alpha_down=0.45, grace_frames=10):
+        self._alpha_up = alpha_up
+        self._alpha_down = alpha_down
         self._grace = grace_frames
         self._score = 0.5
         self._no_face = 0
@@ -61,14 +64,15 @@ class _OutputSmoother:
         self._score = 0.5
         self._no_face = 0
 
-    def update(self, raw_score
+    def update(self, raw_score, face_detected):
         if face_detected:
             self._no_face = 0
-
+            alpha = self._alpha_up if raw_score > self._score else self._alpha_down
+            self._score += alpha * (raw_score - self._score)
         else:
             self._no_face += 1
             if self._no_face > self._grace:
-                self._score *= 0.
+                self._score *= 0.80
         return self._score
 
 
@@ -640,3 +644,141 @@ class XGBoostPipeline:
 
     def __exit__(self, *args):
         self.close()
+
+
+def _resolve_l2cs_weights():
+    for p in [
+        os.path.join(_PROJECT_ROOT, "models", "L2CS-Net", "models", "L2CSNet_gaze360.pkl"),
+        os.path.join(_PROJECT_ROOT, "models", "L2CSNet_gaze360.pkl"),
+        os.path.join(_PROJECT_ROOT, "checkpoints", "L2CSNet_gaze360.pkl"),
+    ]:
+        if os.path.isfile(p):
+            return p
+    return None
+
+
+def is_l2cs_weights_available():
+    return _resolve_l2cs_weights() is not None
+
+
+class L2CSPipeline:
+    # Uses in-tree l2cs.Pipeline (RetinaFace + ResNet50) for gaze estimation
+    # and MediaPipe for head pose, EAR, MAR, and roll de-rotation.
+
+    YAW_THRESHOLD = 22.0
+    PITCH_THRESHOLD = 20.0
+
+    def __init__(self, weights_path=None, arch="ResNet50", device="cpu",
+                 threshold=0.52, detector=None):
+        resolved = weights_path or _resolve_l2cs_weights()
+        if resolved is None or not os.path.isfile(resolved):
+            raise FileNotFoundError(
+                "L2CS weights not found. Place L2CSNet_gaze360.pkl in "
+                "models/L2CS-Net/models/ or checkpoints/"
+            )
+
+        # add in-tree L2CS-Net to import path
+        l2cs_root = os.path.join(_PROJECT_ROOT, "models", "L2CS-Net")
+        if l2cs_root not in sys.path:
+            sys.path.insert(0, l2cs_root)
+        from l2cs import Pipeline as _L2CSPipeline
+
+        import torch
+        # bypass upstream select_device bug by constructing torch.device directly
+        self._pipeline = _L2CSPipeline(
+            weights=pathlib.Path(resolved), arch=arch, device=torch.device(device),
+        )
+
+        self._detector = detector or FaceMeshDetector()
+        self._owns_detector = detector is None
+        self._head_pose = HeadPoseEstimator()
+        self.head_pose = self._head_pose
+        self._eye_scorer = EyeBehaviourScorer()
+        self._threshold = threshold
+        self._smoother = _OutputSmoother()
+
+        print(
+            f"[L2CS] Loaded {resolved} | arch={arch} device={device} "
+            f"yaw_thresh={self.YAW_THRESHOLD} pitch_thresh={self.PITCH_THRESHOLD} "
+            f"threshold={threshold}"
+        )
+
+    @staticmethod
+    def _derotate_gaze(pitch_rad, yaw_rad, roll_deg):
+        # remove head roll so tilted-but-looking-at-screen reads as (0, 0)
+        roll_rad = -math.radians(roll_deg)
+        cos_r, sin_r = math.cos(roll_rad), math.sin(roll_rad)
+        return (yaw_rad * sin_r + pitch_rad * cos_r,
+                yaw_rad * cos_r - pitch_rad * sin_r)
+
+    def process_frame(self, bgr_frame):
+        landmarks = self._detector.process(bgr_frame)
+        h, w = bgr_frame.shape[:2]
+
+        out = {
+            "landmarks": landmarks, "is_focused": False, "raw_score": 0.0,
+            "s_face": 0.0, "s_eye": 0.0, "gaze_pitch": None, "gaze_yaw": None,
+            "yaw": None, "pitch": None, "roll": None, "mar": None, "is_yawning": False,
+        }
+
+        # MediaPipe: head pose, eye/mouth scores
+        roll_deg = 0.0
+        if landmarks is not None:
+            angles = self._head_pose.estimate(landmarks, w, h)
+            if angles is not None:
+                out["yaw"], out["pitch"], out["roll"] = angles
+                roll_deg = angles[2]
+            out["s_face"] = self._head_pose.score(landmarks, w, h)
+            out["s_eye"] = self._eye_scorer.score(landmarks)
+            out["mar"] = compute_mar(landmarks)
+            out["is_yawning"] = out["mar"] > MAR_YAWN_THRESHOLD
+
+        # L2CS gaze (uses its own RetinaFace detector internally)
+        results = self._pipeline.step(bgr_frame)
+
+        if results is None or results.pitch.shape[0] == 0:
+            smoothed = self._smoother.update(0.0, landmarks is not None)
+            out["raw_score"] = smoothed
+            out["is_focused"] = smoothed >= self._threshold
+            return out
+
+        pitch_rad = float(results.pitch[0])
+        yaw_rad = float(results.yaw[0])
+
+        pitch_rad, yaw_rad = self._derotate_gaze(pitch_rad, yaw_rad, roll_deg)
+        out["gaze_pitch"] = pitch_rad
+        out["gaze_yaw"] = yaw_rad
+
+        yaw_deg = abs(math.degrees(yaw_rad))
+        pitch_deg = abs(math.degrees(pitch_rad))
+
+        # fall back to L2CS angles if MediaPipe didn't produce head pose
+        out["yaw"] = out.get("yaw") or math.degrees(yaw_rad)
+        out["pitch"] = out.get("pitch") or math.degrees(pitch_rad)
+
+        # cosine scoring: 1.0 at centre, 0.0 at threshold
+        yaw_t = min(yaw_deg / self.YAW_THRESHOLD, 1.0)
+        pitch_t = min(pitch_deg / self.PITCH_THRESHOLD, 1.0)
+        yaw_score = 0.5 * (1.0 + math.cos(math.pi * yaw_t))
+        pitch_score = 0.5 * (1.0 + math.cos(math.pi * pitch_t))
+        gaze_score = 0.55 * yaw_score + 0.45 * pitch_score
+
+        if out["is_yawning"]:
+            gaze_score = 0.0
+
+        out["raw_score"] = self._smoother.update(float(gaze_score), True)
+        out["is_focused"] = out["raw_score"] >= self._threshold
+        return out
+
+    def reset_session(self):
+        self._smoother.reset()
+
+    def close(self):
+        if self._owns_detector:
+            self._detector.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
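The two pieces of math in L2CSPipeline are roll de-rotation (rotate the gaze vector by minus the head roll, so a tilted-but-centred gaze reads as (0, 0)) and a raised-cosine falloff that maps each angle to a score of 1.0 at centre and 0.0 at its threshold. A standalone sketch of both, reusing the same constants and formulas as the class above:

```python
import math

YAW_THRESHOLD = 22.0    # degrees, as in L2CSPipeline
PITCH_THRESHOLD = 20.0

def derotate_gaze(pitch_rad, yaw_rad, roll_deg):
    # Same 2-D rotation as L2CSPipeline._derotate_gaze: rotate the
    # (pitch, yaw) gaze vector by -roll to cancel head tilt.
    roll_rad = -math.radians(roll_deg)
    cos_r, sin_r = math.cos(roll_rad), math.sin(roll_rad)
    return (yaw_rad * sin_r + pitch_rad * cos_r,
            yaw_rad * cos_r - pitch_rad * sin_r)

def gaze_score(pitch_rad, yaw_rad):
    # Raised-cosine falloff per axis, then the 0.55/0.45 yaw/pitch blend.
    yaw_t = min(abs(math.degrees(yaw_rad)) / YAW_THRESHOLD, 1.0)
    pitch_t = min(abs(math.degrees(pitch_rad)) / PITCH_THRESHOLD, 1.0)
    yaw_s = 0.5 * (1.0 + math.cos(math.pi * yaw_t))
    pitch_s = 0.5 * (1.0 + math.cos(math.pi * pitch_t))
    return 0.55 * yaw_s + 0.45 * pitch_s

# Zero roll leaves the angles unchanged; a centred gaze scores 1.0;
# at the yaw threshold only the 0.45 pitch term survives.
print(derotate_gaze(0.1, 0.2, 0.0))                   # (0.1, 0.2)
print(round(gaze_score(0.0, 0.0), 2))                 # 1.0
print(round(gaze_score(0.0, math.radians(22.0)), 2))  # 0.45
```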