k22056537 commited on
Commit
df9f1dd
·
1 Parent(s): da26163

feat: data collection script, explorer notebook, sample sessions

Browse files
data_preparation/collected/session_20260217_111435.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9689671c2ee3f9263b142afd8efd3d3c62384087a496d99fc969c5d8d9d961d
3
+ size 45316
data_preparation/collected/session_20260217_112240.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2adbdd61854ff37a0e7dfe6a3fc9980ff6f2534430842912de73bd7f97e3c261
3
+ size 182740
data_preparation/explore_collected_data.ipynb ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# FocusGuard — Collected Data Explorer\n",
8
+ "Load `.npz` files from `collect_features.py` and inspect the data before training."
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "metadata": {},
15
+ "outputs": [
16
+ {
17
+ "ename": "FileNotFoundError",
18
+ "evalue": "No .npz files in /content/collected — run collect_features.py first",
19
+ "output_type": "error",
20
+ "traceback": [
21
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
22
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
23
+ "\u001b[0;32m/tmp/ipython-input-251140757.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mnpz_files\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFileNotFoundError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"No .npz files in {COLLECTED_DIR} — run collect_features.py first\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mNPZ_PATH\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnpz_files\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m# latest file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
24
+ "\u001b[0;31mFileNotFoundError\u001b[0m: No .npz files in /content/collected — run collect_features.py first"
25
+ ]
26
+ }
27
+ ],
28
+ "source": [
29
+ "import numpy as np\n",
30
+ "import matplotlib.pyplot as plt\n",
31
+ "import os\n",
32
+ "import glob\n",
33
+ "\n",
34
+ "# auto-find the latest .npz in collected/, or set manually\n",
35
+ "COLLECTED_DIR = os.path.join(os.path.dirname(os.path.abspath(\"__file__\")), \"collected\")\n",
36
+ "npz_files = sorted(glob.glob(os.path.join(COLLECTED_DIR, \"*.npz\")))\n",
37
+ "\n",
38
+ "if not npz_files:\n",
39
+ " raise FileNotFoundError(f\"No .npz files in {COLLECTED_DIR} — run collect_features.py first\")\n",
40
+ "\n",
41
+ "NPZ_PATH = npz_files[-1] # latest file\n",
42
+ "print(f\"Using: {NPZ_PATH}\")\n",
43
+ "\n",
44
+ "data = np.load(NPZ_PATH, allow_pickle=True)\n",
45
+ "features = data['features']\n",
46
+ "labels = data['labels']\n",
47
+ "names = list(data['feature_names'])\n",
48
+ "\n",
49
+ "print(f\"Loaded: {NPZ_PATH}\")\n",
50
+ "print(f\"Samples: {len(labels)}\")\n",
51
+ "print(f\"Features: {features.shape[1]} -> {names}\")\n",
52
+ "print(f\"Labels: 0={int((labels==0).sum())}, 1={int((labels==1).sum())}\")"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {},
58
+ "source": [
59
+ "## 1. Basic Stats"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "import pandas as pd\n",
69
+ "\n",
70
+ "df = pd.DataFrame(features, columns=names)\n",
71
+ "df['label'] = labels\n",
72
+ "\n",
73
+ "print(\"=\" * 60)\n",
74
+ "print(\"FEATURE STATISTICS\")\n",
75
+ "print(\"=\" * 60)\n",
76
+ "df.describe().round(4)"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "# NaN check\n",
86
+ "nan_counts = df.isna().sum()\n",
87
+ "if nan_counts.sum() == 0:\n",
88
+ " print(\"No NaN values found\")\n",
89
+ "else:\n",
90
+ " print(\"NaN counts:\")\n",
91
+ " print(nan_counts[nan_counts > 0])"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "markdown",
96
+ "metadata": {},
97
+ "source": [
98
+ "## 2. Label Distribution"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "metadata": {},
105
+ "outputs": [],
106
+ "source": [
107
+ "n0 = int((labels == 0).sum())\n",
108
+ "n1 = int((labels == 1).sum())\n",
109
+ "total = len(labels)\n",
110
+ "\n",
111
+ "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
112
+ "\n",
113
+ "# bar chart\n",
114
+ "axes[0].bar(['Unfocused (0)', 'Focused (1)'], [n0, n1], color=['#EF476F', '#06D6A0'])\n",
115
+ "axes[0].set_ylabel('Samples')\n",
116
+ "axes[0].set_title('Label Distribution')\n",
117
+ "for i, v in enumerate([n0, n1]):\n",
118
+ " axes[0].text(i, v + total*0.01, f'{v} ({v/total*100:.1f}%)', ha='center', fontsize=10)\n",
119
+ "\n",
120
+ "# label over time\n",
121
+ "axes[1].plot(labels, color='#00B4D8', linewidth=0.5)\n",
122
+ "axes[1].fill_between(range(len(labels)), labels, alpha=0.3, color='#06D6A0')\n",
123
+ "axes[1].set_xlabel('Frame')\n",
124
+ "axes[1].set_ylabel('Label')\n",
125
+ "axes[1].set_title('Label Over Time')\n",
126
+ "axes[1].set_yticks([0, 1])\n",
127
+ "axes[1].set_yticklabels(['Unfocused', 'Focused'])\n",
128
+ "\n",
129
+ "plt.tight_layout()\n",
130
+ "plt.show()\n",
131
+ "\n",
132
+ "# transitions\n",
133
+ "transitions = int(np.sum(np.diff(labels) != 0))\n",
134
+ "print(f\"Transitions: {transitions}\")\n",
135
+ "print(f\"Avg segment: {total/max(transitions,1):.0f} frames ({total/max(transitions,1)/30:.1f}s)\")\n",
136
+ "if transitions < 10:\n",
137
+ " print(\"⚠️ Too few transitions — switch every 10-30s when re-recording\")"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "metadata": {},
143
+ "source": [
144
+ "## 3. Feature Distributions (Focused vs Unfocused)"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "n_features = features.shape[1]\n",
154
+ "cols = 3\n",
155
+ "rows = (n_features + cols - 1) // cols\n",
156
+ "\n",
157
+ "fig, axes = plt.subplots(rows, cols, figsize=(14, rows * 2.5))\n",
158
+ "axes = axes.flatten()\n",
159
+ "\n",
160
+ "for i in range(n_features):\n",
161
+ " ax = axes[i]\n",
162
+ " f0 = features[labels == 0, i]\n",
163
+ " f1 = features[labels == 1, i]\n",
164
+ " ax.hist(f0, bins=40, alpha=0.6, color='#EF476F', label='Unfocused', density=True)\n",
165
+ " ax.hist(f1, bins=40, alpha=0.6, color='#06D6A0', label='Focused', density=True)\n",
166
+ " ax.set_title(names[i], fontsize=10, fontweight='bold')\n",
167
+ " ax.tick_params(labelsize=8)\n",
168
+ " if i == 0:\n",
169
+ " ax.legend(fontsize=8)\n",
170
+ "\n",
171
+ "# hide empty axes\n",
172
+ "for i in range(n_features, len(axes)):\n",
173
+ " axes[i].set_visible(False)\n",
174
+ "\n",
175
+ "plt.suptitle('Feature Distributions by Label', fontsize=14, fontweight='bold', y=1.01)\n",
176
+ "plt.tight_layout()\n",
177
+ "plt.show()"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "markdown",
182
+ "metadata": {},
183
+ "source": [
184
+ "## 4. Feature-Label Correlations"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "metadata": {},
191
+ "outputs": [],
192
+ "source": [
193
+ "correlations = [np.corrcoef(features[:, i], labels)[0, 1] for i in range(n_features)]\n",
194
+ "sort_idx = np.argsort(np.abs(correlations))[::-1]\n",
195
+ "\n",
196
+ "fig, ax = plt.subplots(figsize=(10, 5))\n",
197
+ "colors = ['#06D6A0' if c > 0 else '#EF476F' for c in [correlations[i] for i in sort_idx]]\n",
198
+ "bars = ax.barh([names[i] for i in sort_idx],\n",
199
+ " [correlations[i] for i in sort_idx],\n",
200
+ " color=colors)\n",
201
+ "ax.set_xlabel('Correlation with Label (focused=1)')\n",
202
+ "ax.set_title('Feature-Label Correlations (sorted by |r|)')\n",
203
+ "ax.axvline(0, color='gray', linewidth=0.5)\n",
204
+ "\n",
205
+ "for bar, idx in zip(bars, sort_idx):\n",
206
+ " r = correlations[idx]\n",
207
+ " ax.text(r + (0.01 if r >= 0 else -0.01), bar.get_y() + bar.get_height()/2,\n",
208
+ " f'{r:.3f}', va='center', ha='left' if r >= 0 else 'right', fontsize=9)\n",
209
+ "\n",
210
+ "plt.tight_layout()\n",
211
+ "plt.show()\n",
212
+ "\n",
213
+ "print(\"\\nTop predictive features:\")\n",
214
+ "for i in sort_idx[:5]:\n",
215
+ " print(f\" {names[i]:<20} r = {correlations[i]:+.4f}\")"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "metadata": {},
221
+ "source": [
222
+ "## 5. Feature Correlation Matrix"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "corr_matrix = np.corrcoef(features.T)\n",
232
+ "\n",
233
+ "fig, ax = plt.subplots(figsize=(10, 8))\n",
234
+ "im = ax.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1)\n",
235
+ "ax.set_xticks(range(n_features))\n",
236
+ "ax.set_yticks(range(n_features))\n",
237
+ "ax.set_xticklabels(names, rotation=45, ha='right', fontsize=9)\n",
238
+ "ax.set_yticklabels(names, fontsize=9)\n",
239
+ "ax.set_title('Feature Correlation Matrix')\n",
240
+ "plt.colorbar(im, ax=ax, shrink=0.8)\n",
241
+ "\n",
242
+ "# annotate\n",
243
+ "for i in range(n_features):\n",
244
+ " for j in range(n_features):\n",
245
+ " val = corr_matrix[i, j]\n",
246
+ " if abs(val) > 0.5 and i != j:\n",
247
+ " ax.text(j, i, f'{val:.2f}', ha='center', va='center', fontsize=7,\n",
248
+ " color='white' if abs(val) > 0.7 else 'black')\n",
249
+ "\n",
250
+ "plt.tight_layout()\n",
251
+ "plt.show()"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "markdown",
256
+ "metadata": {},
257
+ "source": [
258
+ "## 6. Features Over Time"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": null,
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "# Plot key features over time with label shading\n",
268
+ "key_features = ['s_face', 's_eye', 'ear_avg', 'yaw', 'pitch']\n",
269
+ "# filter to only features that exist in this file\n",
270
+ "key_features = [f for f in key_features if f in names]\n",
271
+ "\n",
272
+ "fig, axes = plt.subplots(len(key_features) + 1, 1, figsize=(14, (len(key_features)+1) * 1.8),\n",
273
+ " sharex=True)\n",
274
+ "\n",
275
+ "# label timeline\n",
276
+ "axes[0].fill_between(range(len(labels)), labels, alpha=0.4, color='#06D6A0', step='mid')\n",
277
+ "axes[0].set_ylabel('Label')\n",
278
+ "axes[0].set_yticks([0, 1])\n",
279
+ "axes[0].set_yticklabels(['Unfocused', 'Focused'], fontsize=9)\n",
280
+ "axes[0].set_title('Label + Key Features Over Time', fontsize=12, fontweight='bold')\n",
281
+ "\n",
282
+ "for i, feat in enumerate(key_features):\n",
283
+ " idx = names.index(feat)\n",
284
+ " ax = axes[i + 1]\n",
285
+ " ax.plot(features[:, idx], linewidth=0.8, color='#00B4D8')\n",
286
+ " # shade focused regions\n",
287
+ " ax.fill_between(range(len(labels)), ax.get_ylim()[0], ax.get_ylim()[1],\n",
288
+ " where=labels == 1, alpha=0.1, color='green')\n",
289
+ " ax.set_ylabel(feat, fontsize=9)\n",
290
+ "\n",
291
+ "axes[-1].set_xlabel('Frame')\n",
292
+ "plt.tight_layout()\n",
293
+ "plt.show()"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "markdown",
298
+ "metadata": {},
299
+ "source": [
300
+ "## 7. Quality Summary"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "metadata": {},
307
+ "outputs": [],
308
+ "source": [
309
+ "duration_sec = len(labels) / 30.0\n",
310
+ "balance = n1 / max(total, 1)\n",
311
+ "\n",
312
+ "checks = {\n",
313
+ " 'Duration >= 2 min': duration_sec >= 120,\n",
314
+ " 'Samples >= 3000': total >= 3000,\n",
315
+ " 'Balance 30-70%': 0.3 <= balance <= 0.7,\n",
316
+ " 'Transitions >= 10': transitions >= 10,\n",
317
+ " 'No NaN values': int(np.isnan(features).sum()) == 0,\n",
318
+ " 'No constant features': all(features[:, i].std() > 0.001 for i in range(n_features)),\n",
319
+ "}\n",
320
+ "\n",
321
+ "print(\"DATA QUALITY CHECKLIST\")\n",
322
+ "print(\"=\" * 40)\n",
323
+ "for check, passed in checks.items():\n",
324
+ " icon = '✅' if passed else '❌'\n",
325
+ " print(f\" {icon} {check}\")\n",
326
+ "\n",
327
+ "passed = sum(checks.values())\n",
328
+ "print(f\"\\n {passed}/{len(checks)} checks passed\")\n",
329
+ "if passed == len(checks):\n",
330
+ " print(\" Ready for training!\")\n",
331
+ "else:\n",
332
+ " print(\" Re-record or collect more data.\")"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "markdown",
337
+ "metadata": {},
338
+ "source": [
339
+ "## 8. Merge Multiple Sessions (Optional)\n",
340
+ "Run this if you have multiple `.npz` files from different team members."
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": null,
346
+ "metadata": {},
347
+ "outputs": [],
348
+ "source": [
349
+ "COLLECTED_DIR = \"data_preparation/collected/\"\n",
350
+ "\n",
351
+ "all_features = []\n",
352
+ "all_labels = []\n",
353
+ "all_participants = [] # for participant-aware splitting\n",
354
+ "\n",
355
+ "npz_files = sorted([f for f in os.listdir(COLLECTED_DIR) if f.endswith('.npz')])\n",
356
+ "print(f\"Found {len(npz_files)} .npz files:\\n\")\n",
357
+ "\n",
358
+ "for i, fname in enumerate(npz_files):\n",
359
+ " d = np.load(os.path.join(COLLECTED_DIR, fname), allow_pickle=True)\n",
360
+ " f, l = d['features'], d['labels']\n",
361
+ " n = len(l)\n",
362
+ " n1 = int((l == 1).sum())\n",
363
+ " trans = int(np.sum(np.diff(l) != 0))\n",
364
+ " print(f\" [{i}] {fname}\")\n",
365
+ " print(f\" {n} samples, {n1/n*100:.0f}% focused, {trans} transitions, {n/30:.0f}s\")\n",
366
+ " \n",
367
+ " all_features.append(f)\n",
368
+ " all_labels.append(l)\n",
369
+ " all_participants.append(np.full(n, i, dtype=np.int32))\n",
370
+ "\n",
371
+ "if len(all_features) > 0:\n",
372
+ " merged_features = np.concatenate(all_features)\n",
373
+ " merged_labels = np.concatenate(all_labels)\n",
374
+ " merged_participants = np.concatenate(all_participants)\n",
375
+ " \n",
376
+ " print(f\"\\nMerged: {len(merged_labels)} total samples\")\n",
377
+ " print(f\" Focused: {int((merged_labels==1).sum())} ({(merged_labels==1).mean()*100:.1f}%)\")\n",
378
+ " print(f\" Unfocused: {int((merged_labels==0).sum())} ({(merged_labels==0).mean()*100:.1f}%)\")\n",
379
+ " \n",
380
+ " # Save merged\n",
381
+ " out_path = os.path.join(COLLECTED_DIR, \"merged_all.npz\")\n",
382
+ " np.savez(out_path,\n",
383
+ " features=merged_features,\n",
384
+ " labels=merged_labels,\n",
385
+ " participants=merged_participants,\n",
386
+ " feature_names=d['feature_names'])\n",
387
+ " print(f\" Saved -> {out_path}\")\n",
388
+ "else:\n",
389
+ " print(\"No .npz files found\")"
390
+ ]
391
+ }
392
+ ],
393
+ "metadata": {
394
+ "kernelspec": {
395
+ "display_name": "venv",
396
+ "language": "python",
397
+ "name": "python3"
398
+ },
399
+ "language_info": {
400
+ "codemirror_mode": {
401
+ "name": "ipython",
402
+ "version": 3
403
+ },
404
+ "file_extension": ".py",
405
+ "mimetype": "text/x-python",
406
+ "name": "python",
407
+ "nbconvert_exporter": "python",
408
+ "pygments_lexer": "ipython3",
409
+ "version": "3.13.7"
410
+ }
411
+ },
412
+ "nbformat": 4,
413
+ "nbformat_minor": 4
414
+ }
models/attention_model/collect_features.py CHANGED
@@ -1 +1,403 @@
1
- # stub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Collect labeled face mesh features from webcam for training
2
+ #
3
+ # Run the demo, press 1 = focused, 0 = not focused, p = pause, q = save & quit.
4
+ # Each labeled frame saves 17 features (geometric + temporal) + label.
5
+ # Expect 5-10 min per person. Switch focus/unfocus every 10-30 seconds.
6
+ #
7
+ # Usage:
8
+ # python models/attention_model/collect_features.py
9
+ # python models/attention_model/collect_features.py --name alice --duration 600
10
+
11
+ import argparse
12
+ import collections
13
+ import math
14
+ import os
15
+ import sys
16
+ import time
17
+
18
+ import cv2
19
+ import numpy as np
20
+
21
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
22
+ if _PROJECT_ROOT not in sys.path:
23
+ sys.path.insert(0, _PROJECT_ROOT)
24
+
25
+ from models.face_mesh.face_mesh import FaceMeshDetector
26
+ from models.face_orientation.head_pose import HeadPoseEstimator
27
+ from models.eye_behaviour.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar
28
+
29
+ FONT = cv2.FONT_HERSHEY_SIMPLEX
30
+ GREEN = (0, 255, 0)
31
+ RED = (0, 0, 255)
32
+ WHITE = (255, 255, 255)
33
+ YELLOW = (0, 255, 255)
34
+ ORANGE = (0, 165, 255)
35
+ GRAY = (120, 120, 120)
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # 17 features: geometric (11) + derived (2) + temporal (4)
39
+ # ---------------------------------------------------------------------------
40
+ FEATURE_NAMES = [
41
+ # --- geometric (from landmarks each frame) ---
42
+ "ear_left", # 0 Left Eye Aspect Ratio
43
+ "ear_right", # 1 Right Eye Aspect Ratio
44
+ "ear_avg", # 2 Mean EAR
45
+ "h_gaze", # 3 Horizontal iris position
46
+ "v_gaze", # 4 Vertical iris position
47
+ "mar", # 5 Mouth Aspect Ratio
48
+ "yaw", # 6 Head horizontal rotation (degrees)
49
+ "pitch", # 7 Head vertical tilt (degrees)
50
+ "roll", # 8 Head lateral tilt (degrees)
51
+ "s_face", # 9 Cosine-decay head pose score [0,1]
52
+ "s_eye", # 10 Geometric eye score [0,1]
53
+ # --- derived ---
54
+ "gaze_offset", # 11 Distance from gaze centre: sqrt((h-0.5)^2 + (v-0.5)^2)
55
+ "head_deviation", # 12 sqrt(yaw^2 + pitch^2)
56
+ # --- temporal (rolling window) ---
57
+ "perclos", # 13 % eye closure over last 60 frames
58
+ "blink_rate", # 14 Blinks per minute (30s window)
59
+ "closure_duration", # 15 Current sustained eye closure (seconds)
60
+ "yawn_duration", # 16 Current sustained yawn (seconds)
61
+ ]
62
+
63
+ NUM_FEATURES = len(FEATURE_NAMES)
64
+ assert NUM_FEATURES == 17
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # Temporal tracker — keeps rolling history for PERCLOS, blink rate, etc.
69
+ # ---------------------------------------------------------------------------
70
+ class TemporalTracker:
71
+ """Track temporal signals across frames."""
72
+
73
+ EAR_BLINK_THRESH = 0.21 # EAR below this = eyes closed
74
+ MAR_YAWN_THRESH = 0.04 # MAR above this = yawning
75
+ PERCLOS_WINDOW = 60 # frames for PERCLOS
76
+ BLINK_WINDOW_SEC = 30.0 # seconds for blink rate
77
+
78
+ def __init__(self):
79
+ self.ear_history = collections.deque(maxlen=self.PERCLOS_WINDOW)
80
+ self.blink_timestamps = collections.deque() # list of blink end times
81
+ self._eyes_closed = False
82
+ self._closure_start = None # time when eyes first closed
83
+ self._yawn_start = None # time when yawn started
84
+
85
+ def update(self, ear_avg, mar, now=None):
86
+ """Call once per frame. Returns (perclos, blink_rate, closure_dur, yawn_dur)."""
87
+ if now is None:
88
+ now = time.time()
89
+
90
+ # --- PERCLOS ---
91
+ closed = ear_avg < self.EAR_BLINK_THRESH
92
+ self.ear_history.append(1.0 if closed else 0.0)
93
+ perclos = sum(self.ear_history) / len(self.ear_history) if self.ear_history else 0.0
94
+
95
+ # --- Blink detection (closed -> open transition) ---
96
+ if self._eyes_closed and not closed:
97
+ # blink just ended
98
+ self.blink_timestamps.append(now)
99
+ self._eyes_closed = closed
100
+
101
+ # prune old blinks
102
+ cutoff = now - self.BLINK_WINDOW_SEC
103
+ while self.blink_timestamps and self.blink_timestamps[0] < cutoff:
104
+ self.blink_timestamps.popleft()
105
+ blink_rate = len(self.blink_timestamps) * (60.0 / self.BLINK_WINDOW_SEC)
106
+
107
+ # --- Closure duration ---
108
+ if closed:
109
+ if self._closure_start is None:
110
+ self._closure_start = now
111
+ closure_dur = now - self._closure_start
112
+ else:
113
+ self._closure_start = None
114
+ closure_dur = 0.0
115
+
116
+ # --- Yawn duration ---
117
+ yawning = mar > self.MAR_YAWN_THRESH
118
+ if yawning:
119
+ if self._yawn_start is None:
120
+ self._yawn_start = now
121
+ yawn_dur = now - self._yawn_start
122
+ else:
123
+ self._yawn_start = None
124
+ yawn_dur = 0.0
125
+
126
+ return perclos, blink_rate, closure_dur, yawn_dur
127
+
128
+
129
+ # ---------------------------------------------------------------------------
130
+ # Feature extraction (one frame -> 17-dim vector)
131
+ # ---------------------------------------------------------------------------
132
+ def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal):
133
+ """Extract 17 features from one frame's landmarks."""
134
+ from models.eye_behaviour.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear
135
+
136
+ # --- geometric ---
137
+ ear_left = compute_ear(landmarks, _LEFT_EYE_EAR)
138
+ ear_right = compute_ear(landmarks, _RIGHT_EYE_EAR)
139
+ ear_avg = (ear_left + ear_right) / 2.0
140
+ h_gaze, v_gaze = compute_gaze_ratio(landmarks)
141
+ mar = compute_mar(landmarks)
142
+
143
+ angles = head_pose.estimate(landmarks, w, h)
144
+ yaw = angles[0] if angles else 0.0
145
+ pitch = angles[1] if angles else 0.0
146
+ roll = angles[2] if angles else 0.0
147
+
148
+ s_face = head_pose.score(landmarks, w, h)
149
+ s_eye = eye_scorer.score(landmarks)
150
+
151
+ # --- derived ---
152
+ gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
153
+ head_deviation = math.sqrt(yaw ** 2 + pitch ** 2)
154
+
155
+ # --- temporal ---
156
+ perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)
157
+
158
+ return np.array([
159
+ ear_left, ear_right, ear_avg,
160
+ h_gaze, v_gaze,
161
+ mar,
162
+ yaw, pitch, roll,
163
+ s_face, s_eye,
164
+ gaze_offset,
165
+ head_deviation,
166
+ perclos, blink_rate, closure_dur, yawn_dur,
167
+ ], dtype=np.float32)
168
+
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # Quality checks — run at save time
172
+ # ---------------------------------------------------------------------------
173
+ def quality_report(labels):
174
+ """Print warnings about data quality issues."""
175
+ n = len(labels)
176
+ n1 = int((labels == 1).sum())
177
+ n0 = n - n1
178
+ transitions = int(np.sum(np.diff(labels) != 0))
179
+ duration_sec = n / 30.0 # approximate at 30fps
180
+
181
+ warnings = []
182
+
183
+ print(f"\n{'='*50}")
184
+ print(f" DATA QUALITY REPORT")
185
+ print(f"{'='*50}")
186
+ print(f" Total samples : {n}")
187
+ print(f" Focused : {n1} ({n1/max(n,1)*100:.1f}%)")
188
+ print(f" Unfocused : {n0} ({n0/max(n,1)*100:.1f}%)")
189
+ print(f" Duration : {duration_sec:.0f}s ({duration_sec/60:.1f} min)")
190
+ print(f" Transitions : {transitions}")
191
+ if transitions > 0:
192
+ print(f" Avg segment : {n/transitions:.0f} frames ({n/transitions/30:.1f}s)")
193
+
194
+ # checks
195
+ if duration_sec < 120:
196
+ warnings.append(f"TOO SHORT: {duration_sec:.0f}s — aim for 5-10 minutes (300-600s)")
197
+
198
+ if n < 3000:
199
+ warnings.append(f"LOW SAMPLE COUNT: {n} frames — aim for 9000+ (5 min at 30fps)")
200
+
201
+ balance = n1 / max(n, 1)
202
+ if balance < 0.3 or balance > 0.7:
203
+ warnings.append(f"IMBALANCED: {balance:.0%} focused — aim for 35-65% focused")
204
+
205
+ if transitions < 10:
206
+ warnings.append(f"TOO FEW TRANSITIONS: {transitions} — switch every 10-30s, aim for 20+")
207
+
208
+ if transitions == 1:
209
+ warnings.append("SINGLE BLOCK: you recorded one unfocused + one focused block — "
210
+ "model will learn temporal position, not focus patterns")
211
+
212
+ if warnings:
213
+ print(f"\n ⚠️ WARNINGS ({len(warnings)}):")
214
+ for w in warnings:
215
+ print(f" • {w}")
216
+ print(f"\n Consider re-recording this session.")
217
+ else:
218
+ print(f"\n ✅ All checks passed!")
219
+
220
+ print(f"{'='*50}\n")
221
+ return len(warnings) == 0
222
+
223
+
224
+ # ---------------------------------------------------------------------------
225
+ # Main
226
+ # ---------------------------------------------------------------------------
227
+ def main():
228
+ parser = argparse.ArgumentParser(description="Collect labeled attention data from webcam")
229
+ parser.add_argument("--name", type=str, default="session",
230
+ help="Your name or session ID")
231
+ parser.add_argument("--camera", type=int, default=0,
232
+ help="Camera index")
233
+ parser.add_argument("--duration", type=int, default=600,
234
+ help="Max recording time (seconds, default 10 min)")
235
+ parser.add_argument("--output-dir", type=str,
236
+ default=os.path.join(_PROJECT_ROOT, "data_preparation", "collected"),
237
+ help="Where to save .npz files")
238
+ args = parser.parse_args()
239
+
240
+ os.makedirs(args.output_dir, exist_ok=True)
241
+
242
+ detector = FaceMeshDetector()
243
+ head_pose = HeadPoseEstimator()
244
+ eye_scorer = EyeBehaviourScorer()
245
+ temporal = TemporalTracker()
246
+
247
+ cap = cv2.VideoCapture(args.camera)
248
+ if not cap.isOpened():
249
+ print("[COLLECT] ERROR: can't open camera")
250
+ return
251
+
252
+ print("[COLLECT] Data Collection Tool")
253
+ print(f"[COLLECT] Session: {args.name}, max {args.duration}s")
254
+ print(f"[COLLECT] Features per frame: {NUM_FEATURES}")
255
+ print("[COLLECT] Controls:")
256
+ print(" 1 = FOCUSED (looking at screen normally)")
257
+ print(" 0 = NOT FOCUSED (phone, away, eyes closed, yawning)")
258
+ print(" p = pause")
259
+ print(" q = save & quit")
260
+ print()
261
+ print("[COLLECT] TIPS for good data:")
262
+ print(" • Switch between 1 and 0 every 10-30 seconds")
263
+ print(" • Aim for 20+ transitions total")
264
+ print(" • Act out varied scenarios: reading, phone, talking, drowsy")
265
+ print(" • Record at least 5 minutes")
266
+ print()
267
+
268
+ features_list = []
269
+ labels_list = []
270
+ label = None # None = paused
271
+ transitions = 0 # count label switches
272
+ prev_label = None
273
+ status = "PAUSED -- press 1 (focused) or 0 (not focused)"
274
+ t_start = time.time()
275
+ prev_time = time.time()
276
+ fps = 0.0
277
+
278
+ try:
279
+ while True:
280
+ elapsed = time.time() - t_start
281
+ if elapsed > args.duration:
282
+ print(f"[COLLECT] Time limit ({args.duration}s)")
283
+ break
284
+
285
+ ret, frame = cap.read()
286
+ if not ret:
287
+ break
288
+
289
+ h, w = frame.shape[:2]
290
+ landmarks = detector.process(frame)
291
+ face_ok = landmarks is not None
292
+
293
+ # record if labeling + face visible
294
+ if face_ok and label is not None:
295
+ vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
296
+ features_list.append(vec)
297
+ labels_list.append(label)
298
+
299
+ # count transitions
300
+ if prev_label is not None and label != prev_label:
301
+ transitions += 1
302
+ prev_label = label
303
+
304
+ now = time.time()
305
+ fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
306
+ prev_time = now
307
+
308
+ # --- draw UI ---
309
+ n = len(labels_list)
310
+ n1 = sum(1 for x in labels_list if x == 1)
311
+ n0 = n - n1
312
+ remaining = max(0, args.duration - elapsed)
313
+
314
+ # top bar
315
+ bar_color = GREEN if label == 1 else (RED if label == 0 else (80, 80, 80))
316
+ cv2.rectangle(frame, (0, 0), (w, 70), (0, 0, 0), -1)
317
+ cv2.putText(frame, status, (10, 22), FONT, 0.55, bar_color, 2, cv2.LINE_AA)
318
+ cv2.putText(frame, f"Samples: {n} (F:{n1} U:{n0}) Switches: {transitions}",
319
+ (10, 48), FONT, 0.42, WHITE, 1, cv2.LINE_AA)
320
+ cv2.putText(frame, f"FPS:{fps:.0f}", (w - 80, 22), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
321
+ cv2.putText(frame, f"{int(remaining)}s left", (w - 80, 48), FONT, 0.42, YELLOW, 1, cv2.LINE_AA)
322
+
323
+ # balance bar
324
+ if n > 0:
325
+ bar_w = min(w - 20, 300)
326
+ bar_x = w - bar_w - 10
327
+ bar_y = 58
328
+ frac = n1 / n
329
+ cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + 8), (40, 40, 40), -1)
330
+ cv2.rectangle(frame, (bar_x, bar_y), (bar_x + int(bar_w * frac), bar_y + 8), GREEN, -1)
331
+ cv2.putText(frame, f"{frac:.0%}F", (bar_x + bar_w + 4, bar_y + 8),
332
+ FONT, 0.3, GRAY, 1, cv2.LINE_AA)
333
+
334
+ if not face_ok:
335
+ cv2.putText(frame, "NO FACE", (w // 2 - 60, h // 2), FONT, 0.7, RED, 2, cv2.LINE_AA)
336
+
337
+ # red dot = recording
338
+ if label is not None and face_ok:
339
+ cv2.circle(frame, (w - 20, 80), 8, RED, -1)
340
+
341
+ # live warnings
342
+ warn_y = h - 35
343
+ if n > 100 and transitions < 3:
344
+ cv2.putText(frame, "! Switch more often (aim for 20+ transitions)",
345
+ (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
346
+ warn_y -= 18
347
+ if elapsed > 30 and n > 0:
348
+ bal = n1 / n
349
+ if bal < 0.25 or bal > 0.75:
350
+ cv2.putText(frame, f"! Imbalanced ({bal:.0%} focused) - record more of the other",
351
+ (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
352
+ warn_y -= 18
353
+
354
+ cv2.putText(frame, "1:focused 0:unfocused p:pause q:save+quit",
355
+ (10, h - 10), FONT, 0.38, GRAY, 1, cv2.LINE_AA)
356
+
357
+ cv2.imshow("FocusGuard -- Data Collection", frame)
358
+
359
+ key = cv2.waitKey(1) & 0xFF
360
+ if key == ord("1"):
361
+ label = 1
362
+ status = "Recording: FOCUSED"
363
+ print(f"[COLLECT] -> FOCUSED (n={n}, transitions={transitions})")
364
+ elif key == ord("0"):
365
+ label = 0
366
+ status = "Recording: NOT FOCUSED"
367
+ print(f"[COLLECT] -> NOT FOCUSED (n={n}, transitions={transitions})")
368
+ elif key == ord("p"):
369
+ label = None
370
+ status = "PAUSED"
371
+ print(f"[COLLECT] paused (n={n})")
372
+ elif key == ord("q"):
373
+ break
374
+
375
+ finally:
376
+ cap.release()
377
+ cv2.destroyAllWindows()
378
+ detector.close()
379
+
380
+ if len(features_list) > 0:
381
+ feats = np.stack(features_list)
382
+ labs = np.array(labels_list, dtype=np.int64)
383
+
384
+ ts = time.strftime("%Y%m%d_%H%M%S")
385
+ fname = f"{args.name}_{ts}.npz"
386
+ fpath = os.path.join(args.output_dir, fname)
387
+ np.savez(fpath,
388
+ features=feats,
389
+ labels=labs,
390
+ feature_names=np.array(FEATURE_NAMES))
391
+
392
+ print(f"\n[COLLECT] Saved {len(labs)} samples -> {fpath}")
393
+ print(f" Shape: {feats.shape} ({NUM_FEATURES} features)")
394
+
395
+ quality_report(labs)
396
+ else:
397
+ print("\n[COLLECT] No data collected")
398
+
399
+ print("[COLLECT] Done")
400
+
401
+
402
+ if __name__ == "__main__":
403
+ main()