BcantCode committed (verified)
Commit 687b215 · Parent(s): a0e0750

Upload src/model.py with huggingface_hub

Files changed (1): src/model.py (+363, -0)
src/model.py ADDED
@@ -0,0 +1,363 @@
"""
GazeInception-Lite: Gated Inception Model for Mobile Eye Gaze Estimation

Architecture:
- Input: 64x64 RGB eye crop (left + right eye stacked as 2-channel or 128x64 side-by-side)
- Gated Inception Blocks: each inception block has a lightweight gate (squeeze-excitation
  style) that learns to down-weight branches that contribute little, reducing wasted compute
- Multi-scale feature extraction via inception (1x1, 3x3, 5x5 parallel convolutions)
- Coordinate Attention for spatial awareness
- Output: (x, y) screen coordinates normalized to [0, 1]

Design goals:
- < 500K parameters for fast mobile inference
- TFLite compatible (no unsupported ops)
- Works in low light (trained with illumination augmentation)
- Handles glasses (trained with glasses augmentation)
- Handles lazy eye / strabismus (trained with per-eye asymmetric augmentation)
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

class GatedInceptionBlock(layers.Layer):
    """
    Inception block with gating mechanism.

    The gate is a lightweight learned sigmoid that scales each inception branch.
    Branches with low gate values contribute near-zero, effectively being "skipped"
    at inference, reducing wasted compute via learned conditional computation.

    Branches:
        1. 1x1 conv (point features)
        2. 1x1 -> 3x3 depthwise separable conv (local features)
        3. 1x1 -> 5x5 depthwise separable conv (wider context)
        4. 3x3 max pool -> 1x1 conv (pooled features)

    Gate: Global Average Pool -> Dense -> Sigmoid per branch
    """

    def __init__(self, filters_1x1, filters_3x3_reduce, filters_3x3,
                 filters_5x5_reduce, filters_5x5, filters_pool, **kwargs):
        super().__init__(**kwargs)
        self.filters_1x1 = filters_1x1
        self.filters_3x3_reduce = filters_3x3_reduce
        self.filters_3x3 = filters_3x3
        self.filters_5x5_reduce = filters_5x5_reduce
        self.filters_5x5 = filters_5x5
        self.filters_pool = filters_pool
        self.num_branches = 4

        # Branch 1: 1x1
        self.branch1_conv = layers.Conv2D(filters_1x1, 1, padding='same', use_bias=False)
        self.branch1_bn = layers.BatchNormalization()

        # Branch 2: 1x1 -> 3x3 depthwise separable
        self.branch2_reduce = layers.Conv2D(filters_3x3_reduce, 1, padding='same', use_bias=False)
        self.branch2_reduce_bn = layers.BatchNormalization()
        self.branch2_dw = layers.DepthwiseConv2D(3, padding='same', use_bias=False)
        self.branch2_pw = layers.Conv2D(filters_3x3, 1, padding='same', use_bias=False)
        self.branch2_bn = layers.BatchNormalization()

        # Branch 3: 1x1 -> 5x5 depthwise separable
        self.branch3_reduce = layers.Conv2D(filters_5x5_reduce, 1, padding='same', use_bias=False)
        self.branch3_reduce_bn = layers.BatchNormalization()
        self.branch3_dw = layers.DepthwiseConv2D(5, padding='same', use_bias=False)
        self.branch3_pw = layers.Conv2D(filters_5x5, 1, padding='same', use_bias=False)
        self.branch3_bn = layers.BatchNormalization()

        # Branch 4: MaxPool -> 1x1
        self.branch4_pool = layers.MaxPooling2D(3, strides=1, padding='same')
        self.branch4_conv = layers.Conv2D(filters_pool, 1, padding='same', use_bias=False)
        self.branch4_bn = layers.BatchNormalization()

        # Gating mechanism: learns one sigmoid weight per branch
        self.gate_pool = layers.GlobalAveragePooling2D()
        self.gate_dense1 = layers.Dense(self.num_branches * 4, activation='relu')
        self.gate_dense2 = layers.Dense(self.num_branches, activation='sigmoid')

        # Shared activation
        self.relu = layers.ReLU()

    def call(self, x, training=False):
        # Compute gate values (how strongly to activate each branch)
        gate_input = self.gate_pool(x)
        gate = self.gate_dense1(gate_input)
        gate = self.gate_dense2(gate)  # [batch, 4] sigmoid values

        # Branch 1
        b1 = self.branch1_conv(x)
        b1 = self.branch1_bn(b1, training=training)
        b1 = self.relu(b1)

        # Branch 2
        b2 = self.branch2_reduce(x)
        b2 = self.branch2_reduce_bn(b2, training=training)
        b2 = self.relu(b2)
        b2 = self.branch2_dw(b2)
        b2 = self.branch2_pw(b2)
        b2 = self.branch2_bn(b2, training=training)
        b2 = self.relu(b2)

        # Branch 3
        b3 = self.branch3_reduce(x)
        b3 = self.branch3_reduce_bn(b3, training=training)
        b3 = self.relu(b3)
        b3 = self.branch3_dw(b3)
        b3 = self.branch3_pw(b3)
        b3 = self.branch3_bn(b3, training=training)
        b3 = self.relu(b3)

        # Branch 4
        b4 = self.branch4_pool(x)
        b4 = self.branch4_conv(b4)
        b4 = self.branch4_bn(b4, training=training)
        b4 = self.relu(b4)

        # Apply gates: multiply each branch by its per-sample gate scalar,
        # reshaped for broadcasting over H, W, C
        g1 = tf.reshape(gate[:, 0], [-1, 1, 1, 1])
        g2 = tf.reshape(gate[:, 1], [-1, 1, 1, 1])
        g3 = tf.reshape(gate[:, 2], [-1, 1, 1, 1])
        g4 = tf.reshape(gate[:, 3], [-1, 1, 1, 1])

        # Concatenate gated branches
        return tf.concat([b1 * g1, b2 * g2, b3 * g3, b4 * g4], axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({
            'filters_1x1': self.filters_1x1,
            'filters_3x3_reduce': self.filters_3x3_reduce,
            'filters_3x3': self.filters_3x3,
            'filters_5x5_reduce': self.filters_5x5_reduce,
            'filters_5x5': self.filters_5x5,
            'filters_pool': self.filters_pool,
        })
        return config

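
# Illustrative helper (not part of the original upload): reads out the
# per-branch sigmoid gate values of a GatedInceptionBlock, making the
# "learned conditional computation" described above observable. The gate
# sub-layers build their weights lazily on first use.
def inspect_gate_values(block, x):
    """Return the [batch, 4] per-branch gate activations for input x."""
    gate = block.gate_pool(x)       # squeeze: global average pool over H, W
    gate = block.gate_dense1(gate)  # hidden gating layer (ReLU)
    return block.gate_dense2(gate)  # one sigmoid weight per branch
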

class CoordinateAttention(layers.Layer):
    """
    Coordinate Attention module (Hou et al., 2021).
    Encodes spatial position information into channel attention for better
    localization. Critical for gaze estimation, where the spatial position
    of the iris matters.
    """

    def __init__(self, reduction_ratio=4, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        channels = input_shape[-1]
        reduced_channels = max(channels // self.reduction_ratio, 8)

        self.conv_reduce = layers.Conv2D(reduced_channels, 1, use_bias=False)
        self.bn = layers.BatchNormalization()
        self.relu = layers.ReLU()

        self.conv_h = layers.Conv2D(channels, 1, activation='sigmoid')
        self.conv_w = layers.Conv2D(channels, 1, activation='sigmoid')

        super().build(input_shape)

    def call(self, x, training=False):
        # Pool along width (keep height) and along height (keep width)
        h_att = tf.reduce_mean(x, axis=2, keepdims=True)  # [B, H, 1, C]
        w_att = tf.reduce_mean(x, axis=1, keepdims=True)  # [B, 1, W, C]

        # Transpose w_att to match h_att's layout for concatenation
        w_att_t = tf.transpose(w_att, perm=[0, 2, 1, 3])  # [B, W, 1, C]

        # Concatenate and reduce channels
        combined = tf.concat([h_att, w_att_t], axis=1)  # [B, H+W, 1, C]
        combined = self.conv_reduce(combined)
        combined = self.bn(combined, training=training)
        combined = self.relu(combined)

        # Split back into the height and width halves
        h_len = tf.shape(h_att)[1]
        h_out = combined[:, :h_len, :, :]
        w_out = combined[:, h_len:, :, :]

        # Generate attention maps for each axis
        h_att_map = self.conv_h(h_out)  # [B, H, 1, C]
        w_att_map = self.conv_w(w_out)  # [B, W, 1, C]
        w_att_map = tf.transpose(w_att_map, perm=[0, 2, 1, 3])  # [B, 1, W, C]

        # Apply attention along both axes
        return x * h_att_map * w_att_map

    def get_config(self):
        config = super().get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config

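
# Usage note (illustrative, not from the original file): CoordinateAttention
# preserves its input shape, so it can be dropped between any two blocks:
#   att = CoordinateAttention(reduction_ratio=4)
#   att(tf.zeros([1, 8, 8, 128])).shape  # -> TensorShape([1, 8, 8, 128])
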

def build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2):
    """
    Build the GazeInception-Lite model.

    Architecture:
        Input (64x64x3) -> Stem -> GatedInception1 -> GatedInception2 ->
        CoordAttention -> GatedInception3 -> GlobalPool -> Dense -> (x, y)

    Total: ~350K parameters
    """
    inputs = layers.Input(shape=input_shape, name='eye_image')

    # Stem: lightweight feature extraction
    x = layers.Conv2D(32, 3, strides=2, padding='same', use_bias=False)(inputs)  # 32x32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(32, 3, padding='same', use_bias=False)(x)  # 32x32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Gated Inception Block 1 (32x32 -> 16x16)
    x = GatedInceptionBlock(
        filters_1x1=16,
        filters_3x3_reduce=16, filters_3x3=24,
        filters_5x5_reduce=8, filters_5x5=12,
        filters_pool=12,
        name='gated_inception_1'
    )(x)  # output: 64 channels
    x = layers.MaxPooling2D(2)(x)  # 16x16

    # Gated Inception Block 2 (16x16 -> 8x8)
    x = GatedInceptionBlock(
        filters_1x1=32,
        filters_3x3_reduce=24, filters_3x3=48,
        filters_5x5_reduce=12, filters_5x5=24,
        filters_pool=24,
        name='gated_inception_2'
    )(x)  # output: 128 channels
    x = layers.MaxPooling2D(2)(x)  # 8x8

    # Coordinate Attention: encodes spatial position for gaze direction
    x = CoordinateAttention(reduction_ratio=4, name='coord_attention')(x)

    # Gated Inception Block 3 (8x8 -> 4x4)
    x = GatedInceptionBlock(
        filters_1x1=48,
        filters_3x3_reduce=32, filters_3x3=64,
        filters_5x5_reduce=16, filters_5x5=32,
        filters_pool=32,
        name='gated_inception_3'
    )(x)  # output: 176 channels
    x = layers.MaxPooling2D(2)(x)  # 4x4

    # Global feature aggregation
    x = layers.GlobalAveragePooling2D()(x)

    # Regression head
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)

    # Output: (x, y) screen coordinates in [0, 1]
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(x)

    return Model(inputs=inputs, outputs=outputs, name='GazeInceptionLite')

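
# Hedged training-setup sketch (the training script is not part of this file;
# the optimizer and loss choices below are assumptions, not the author's).
# Since the head emits sigmoid (x, y) in [0, 1], plain MSE regression on
# normalized screen coordinates is a natural fit.
def compile_for_gaze_regression(model, learning_rate=1e-3):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss='mse',       # squared distance on normalized coordinates
        metrics=['mae'],  # mean absolute error in normalized units
    )
    return model
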

def build_dual_eye_model(eye_shape=(64, 64, 3), face_shape=(64, 64, 3), num_outputs=2):
    """
    Full model with dual eye inputs + face context.

    This handles lazy eye by processing each eye independently through a
    shared-weight gated inception backbone, then combining with face features.
    Each eye gets its own gaze features, and the model learns to handle
    asymmetric eye conditions (strabismus/amblyopia).

    Inputs:
        - left_eye: 64x64x3 crop
        - right_eye: 64x64x3 crop
        - face: 64x64x3 crop (provides head pose context)

    Output:
        - (x, y) normalized screen coordinates
    """
    left_eye_input = layers.Input(shape=eye_shape, name='left_eye')
    right_eye_input = layers.Input(shape=eye_shape, name='right_eye')
    face_input = layers.Input(shape=face_shape, name='face')

    # Shared eye feature extractor (gated inception backbone).
    # The backbone's dense head is discarded: we tap the model at its
    # GlobalAveragePooling2D layer to get a 176-d feature vector.
    eye_backbone = build_gaze_inception_lite(input_shape=eye_shape, num_outputs=2)
    gap_layer = None
    for layer in eye_backbone.layers:
        if isinstance(layer, layers.GlobalAveragePooling2D):
            gap_layer = layer
            break
    eye_feature_extractor = Model(
        inputs=eye_backbone.input,
        outputs=gap_layer.output,
        name='eye_feature_extractor'
    )

    # Extract features for each eye independently (shared weights)
    left_features = eye_feature_extractor(left_eye_input)    # [B, 176]
    right_features = eye_feature_extractor(right_eye_input)  # [B, 176]

    # Lightweight face context extractor (head pose proxy)
    f = layers.Conv2D(16, 3, strides=2, padding='same', activation='relu')(face_input)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.GlobalAveragePooling2D()(f)
    face_features = layers.Dense(64, activation='relu')(f)  # [B, 64]

    # Combine: left eye + right eye + face. Because the eyes are separate
    # inputs, the model can learn eye asymmetry (lazy eye).
    combined = layers.Concatenate()([left_features, right_features, face_features])

    # Fusion head
    x = layers.Dense(128, activation='relu')(combined)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(x)

    return Model(
        inputs=[left_eye_input, right_eye_input, face_input],
        outputs=outputs,
        name='GazeInceptionLite_DualEye'
    )

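
# Hedged sketch of the "per-eye asymmetric augmentation" mentioned in the
# module docstring (the actual training pipeline is not included here):
# each eye crop gets an independent photometric jitter, so the model cannot
# assume both eyes look alike (strabismus/amblyopia robustness).
def asymmetric_eye_augmentation(left_eye, right_eye, max_delta=0.3):
    left_eye = tf.image.random_brightness(left_eye, max_delta)
    left_eye = tf.image.random_contrast(left_eye, 0.7, 1.3)
    right_eye = tf.image.random_brightness(right_eye, max_delta)
    right_eye = tf.image.random_contrast(right_eye, 0.7, 1.3)
    return left_eye, right_eye
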

if __name__ == '__main__':
    # Test single eye model
    model_single = build_gaze_inception_lite()
    model_single.summary()
    print(f"\nSingle eye model params: {model_single.count_params():,}")

    # Test with random input
    test_input = np.random.rand(2, 64, 64, 3).astype(np.float32)
    output = model_single(test_input)
    print(f"Output shape: {output.shape}")
    print(f"Output values: {output.numpy()}")

    print("\n" + "=" * 60)

    # Test dual eye model
    model_dual = build_dual_eye_model()
    model_dual.summary()
    print(f"\nDual eye model params: {model_dual.count_params():,}")

    test_left = np.random.rand(2, 64, 64, 3).astype(np.float32)
    test_right = np.random.rand(2, 64, 64, 3).astype(np.float32)
    test_face = np.random.rand(2, 64, 64, 3).astype(np.float32)
    output = model_dual([test_left, test_right, test_face])
    print(f"Output shape: {output.shape}")
    print(f"Output values: {output.numpy()}")
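
    # Hedged TFLite export sketch, exercising the "TFLite compatible" design
    # goal from the module docstring. The output path and optimization flag
    # are illustrative, not from the original repo.
    converter = tf.lite.TFLiteConverter.from_keras_model(model_single)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with open('gaze_inception_lite.tflite', 'wb') as fp:
        fp.write(tflite_model)
    print(f"TFLite model size: {len(tflite_model) / 1024:.1f} KB")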