Upload src/model.py with huggingface_hub
src/model.py  ADDED  (+363, -0)
@@ -0,0 +1,363 @@
"""
GazeInception-Lite: Gated Inception Model for Mobile Eye Gaze Estimation

Architecture:
- Input: 64x64x3 RGB eye crop (the dual-eye variant takes left eye, right eye,
  and a face crop as separate inputs)
- Gated Inception Blocks: each inception block has a lightweight gate
  (squeeze-excitation style) that learns to down-weight branches that
  contribute little, reducing useless compute
- Multi-scale feature extraction via inception (1x1, 3x3, 5x5 parallel convolutions)
- Coordinate Attention for spatial awareness
- Output: (x, y) screen coordinates normalized to [0, 1]

Design goals:
- < 500K parameters for fast mobile inference
- TFLite compatible (no unsupported ops)
- Works in the dark (trained with illumination augmentation)
- Handles glasses (trained with glasses augmentation)
- Handles lazy eye / strabismus (trained with per-eye asymmetric augmentation)
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
import numpy as np
class GatedInceptionBlock(layers.Layer):
    """
    Inception block with a gating mechanism.

    The gate is a lightweight learned sigmoid that scales each inception branch.
    Branches with low gate values contribute near-zero output, effectively being
    "skipped" at inference — reducing useless compute via learned conditional
    computation.

    Branches:
      1. 1x1 conv (point features)
      2. 1x1 -> 3x3 depthwise separable conv (local features)
      3. 1x1 -> 5x5 depthwise separable conv (wider context)
      4. 3x3 max pool -> 1x1 conv (pooled features)

    Gate: Global Average Pool -> Dense -> Sigmoid, one value per branch.
    """

    def __init__(self, filters_1x1, filters_3x3_reduce, filters_3x3,
                 filters_5x5_reduce, filters_5x5, filters_pool, **kwargs):
        super().__init__(**kwargs)
        self.filters_1x1 = filters_1x1
        self.filters_3x3_reduce = filters_3x3_reduce
        self.filters_3x3 = filters_3x3
        self.filters_5x5_reduce = filters_5x5_reduce
        self.filters_5x5 = filters_5x5
        self.filters_pool = filters_pool
        self.num_branches = 4

        # Branch 1: 1x1
        self.branch1_conv = layers.Conv2D(filters_1x1, 1, padding='same', use_bias=False)
        self.branch1_bn = layers.BatchNormalization()

        # Branch 2: 1x1 -> 3x3 depthwise separable
        self.branch2_reduce = layers.Conv2D(filters_3x3_reduce, 1, padding='same', use_bias=False)
        self.branch2_reduce_bn = layers.BatchNormalization()
        self.branch2_conv = layers.DepthwiseConv2D(3, padding='same', use_bias=False)
        self.branch2_pw = layers.Conv2D(filters_3x3, 1, padding='same', use_bias=False)
        self.branch2_bn = layers.BatchNormalization()

        # Branch 3: 1x1 -> 5x5 depthwise separable
        self.branch3_reduce = layers.Conv2D(filters_5x5_reduce, 1, padding='same', use_bias=False)
        self.branch3_reduce_bn = layers.BatchNormalization()
        self.branch3_dw = layers.DepthwiseConv2D(5, padding='same', use_bias=False)
        self.branch3_pw = layers.Conv2D(filters_5x5, 1, padding='same', use_bias=False)
        self.branch3_bn = layers.BatchNormalization()

        # Branch 4: MaxPool -> 1x1
        self.branch4_pool = layers.MaxPooling2D(3, strides=1, padding='same')
        self.branch4_conv = layers.Conv2D(filters_pool, 1, padding='same', use_bias=False)
        self.branch4_bn = layers.BatchNormalization()

        # Gating mechanism: learns one scalar weight per branch
        self.gate_pool = layers.GlobalAveragePooling2D()
        self.gate_dense1 = layers.Dense(self.num_branches * 4, activation='relu')
        self.gate_dense2 = layers.Dense(self.num_branches, activation='sigmoid')

        # Shared activation
        self.relu = layers.ReLU()

    def call(self, x, training=False):
        # Compute gate values (how strongly to activate each branch)
        gate_input = self.gate_pool(x)
        gate = self.gate_dense1(gate_input)
        gate = self.gate_dense2(gate)  # [batch, 4] sigmoid values

        # Branch 1
        b1 = self.branch1_conv(x)
        b1 = self.branch1_bn(b1, training=training)
        b1 = self.relu(b1)

        # Branch 2
        b2 = self.branch2_reduce(x)
        b2 = self.branch2_reduce_bn(b2, training=training)
        b2 = self.relu(b2)
        b2 = self.branch2_conv(b2)
        b2 = self.branch2_pw(b2)
        b2 = self.branch2_bn(b2, training=training)
        b2 = self.relu(b2)

        # Branch 3
        b3 = self.branch3_reduce(x)
        b3 = self.branch3_reduce_bn(b3, training=training)
        b3 = self.relu(b3)
        b3 = self.branch3_dw(b3)
        b3 = self.branch3_pw(b3)
        b3 = self.branch3_bn(b3, training=training)
        b3 = self.relu(b3)

        # Branch 4
        b4 = self.branch4_pool(x)
        b4 = self.branch4_conv(b4)
        b4 = self.branch4_bn(b4, training=training)
        b4 = self.relu(b4)

        # Apply gates: gate[:, i] is a per-sample scalar; reshape for broadcasting
        g1 = tf.reshape(gate[:, 0], [-1, 1, 1, 1])
        g2 = tf.reshape(gate[:, 1], [-1, 1, 1, 1])
        g3 = tf.reshape(gate[:, 2], [-1, 1, 1, 1])
        g4 = tf.reshape(gate[:, 3], [-1, 1, 1, 1])

        b1 = b1 * g1
        b2 = b2 * g2
        b3 = b3 * g3
        b4 = b4 * g4

        # Concatenate gated branches:
        # filters_1x1 + filters_3x3 + filters_5x5 + filters_pool output channels
        return tf.concat([b1, b2, b3, b4], axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({
            'filters_1x1': self.filters_1x1,
            'filters_3x3_reduce': self.filters_3x3_reduce,
            'filters_3x3': self.filters_3x3,
            'filters_5x5_reduce': self.filters_5x5_reduce,
            'filters_5x5': self.filters_5x5,
            'filters_pool': self.filters_pool,
        })
        return config


class CoordinateAttention(layers.Layer):
    """
    Coordinate Attention module (Hou et al., 2021).

    Encodes spatial position info into channel attention for better localization.
    Critical for gaze estimation, where the spatial position of the iris matters.
    """

    def __init__(self, reduction_ratio=4, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        channels = input_shape[-1]
        reduced_channels = max(channels // self.reduction_ratio, 8)

        # Directional pooling: average over width (keep H) and over height (keep W)
        self.pool_h = layers.Lambda(lambda x: tf.reduce_mean(x, axis=2, keepdims=True))
        self.pool_w = layers.Lambda(lambda x: tf.reduce_mean(x, axis=1, keepdims=True))

        self.conv_reduce = layers.Conv2D(reduced_channels, 1, use_bias=False)
        self.bn = layers.BatchNormalization()
        self.relu = layers.ReLU()

        self.conv_h = layers.Conv2D(channels, 1, activation='sigmoid')
        self.conv_w = layers.Conv2D(channels, 1, activation='sigmoid')

        super().build(input_shape)

    def call(self, x, training=False):
        h_att = self.pool_h(x)  # [B, H, 1, C] (pooled along width)
        w_att = self.pool_w(x)  # [B, 1, W, C] (pooled along height)

        # Transpose w_att to match h_att's layout for concatenation
        w_att_t = tf.transpose(w_att, perm=[0, 2, 1, 3])  # [B, W, 1, C]

        # Concatenate along the spatial axis and reduce channels jointly
        combined = tf.concat([h_att, w_att_t], axis=1)  # [B, H+W, 1, C]
        combined = self.conv_reduce(combined)
        combined = self.bn(combined, training=training)
        combined = self.relu(combined)

        # Split back into the height and width parts
        h_len = tf.shape(h_att)[1]
        h_out = combined[:, :h_len, :, :]
        w_out = combined[:, h_len:, :, :]

        # Generate attention maps
        h_att_map = self.conv_h(h_out)  # [B, H, 1, C]
        w_att_map = self.conv_w(w_out)  # [B, W, 1, C]
        w_att_map = tf.transpose(w_att_map, perm=[0, 2, 1, 3])  # [B, 1, W, C]

        # Apply attention via broadcasting
        return x * h_att_map * w_att_map

    def get_config(self):
        config = super().get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config


def build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2):
    """
    Build the GazeInception-Lite model.

    Architecture:
        Input (64x64x3) -> Stem -> GatedInception1 -> GatedInception2 ->
        CoordAttention -> GatedInception3 -> GlobalPool -> Dense -> (x, y)

    Total: ~350K parameters
    """
    inputs = layers.Input(shape=input_shape, name='eye_image')

    # Stem: lightweight feature extraction
    x = layers.Conv2D(32, 3, strides=2, padding='same', use_bias=False)(inputs)  # 32x32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(32, 3, padding='same', use_bias=False)(x)  # 32x32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Gated Inception Block 1 (32x32 -> 16x16)
    x = GatedInceptionBlock(
        filters_1x1=16,
        filters_3x3_reduce=16, filters_3x3=24,
        filters_5x5_reduce=8, filters_5x5=12,
        filters_pool=12,
        name='gated_inception_1'
    )(x)  # output: 64 channels
    x = layers.MaxPooling2D(2)(x)  # 16x16

    # Gated Inception Block 2 (16x16 -> 8x8)
    x = GatedInceptionBlock(
        filters_1x1=32,
        filters_3x3_reduce=24, filters_3x3=48,
        filters_5x5_reduce=12, filters_5x5=24,
        filters_pool=24,
        name='gated_inception_2'
    )(x)  # output: 128 channels
    x = layers.MaxPooling2D(2)(x)  # 8x8

    # Coordinate Attention: encodes spatial position for gaze direction
    x = CoordinateAttention(reduction_ratio=4, name='coord_attention')(x)

    # Gated Inception Block 3 (8x8 -> 4x4)
    x = GatedInceptionBlock(
        filters_1x1=48,
        filters_3x3_reduce=32, filters_3x3=64,
        filters_5x5_reduce=16, filters_5x5=32,
        filters_pool=32,
        name='gated_inception_3'
    )(x)  # output: 176 channels
    x = layers.MaxPooling2D(2)(x)  # 4x4

    # Global feature aggregation
    x = layers.GlobalAveragePooling2D()(x)

    # Regression head
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)

    # Output: (x, y) screen coordinates in [0, 1]
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(x)

    model = Model(inputs=inputs, outputs=outputs, name='GazeInceptionLite')
    return model


def build_dual_eye_model(eye_shape=(64, 64, 3), face_shape=(64, 64, 3), num_outputs=2):
    """
    Full model with dual eye inputs plus face context.

    Handles lazy eye by processing each eye independently through a
    shared-weight gated-inception backbone, then combining with face features.
    Because each eye gets its own feature vector, the model can learn
    asymmetric eye conditions (strabismus/amblyopia).

    Inputs:
        - left_eye: 64x64x3 crop
        - right_eye: 64x64x3 crop
        - face: 64x64x3 crop (provides head-pose context)

    Output:
        - (x, y) normalized screen coordinates
    """
    left_eye_input = layers.Input(shape=eye_shape, name='left_eye')
    right_eye_input = layers.Input(shape=eye_shape, name='right_eye')
    face_input = layers.Input(shape=face_shape, name='face')

    # Shared eye feature extractor (gated inception backbone).
    # Truncate the backbone at its GlobalAveragePooling2D layer to get
    # pooled features instead of the (x, y) regression output.
    eye_backbone = build_gaze_inception_lite(input_shape=eye_shape, num_outputs=2)
    gap_layer = None
    for layer in eye_backbone.layers:
        if isinstance(layer, layers.GlobalAveragePooling2D):
            gap_layer = layer
            break
    if gap_layer is None:
        raise ValueError('Backbone has no GlobalAveragePooling2D layer')
    eye_feature_extractor = Model(
        inputs=eye_backbone.input,
        outputs=gap_layer.output,
        name='eye_feature_extractor'
    )

    # Extract features for each eye independently (shared weights)
    left_features = eye_feature_extractor(left_eye_input)    # [B, 176]
    right_features = eye_feature_extractor(right_eye_input)  # [B, 176]

    # Lightweight face context extractor (head-pose proxy)
    f = layers.Conv2D(16, 3, strides=2, padding='same', activation='relu')(face_input)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.GlobalAveragePooling2D()(f)
    face_features = layers.Dense(64, activation='relu')(f)  # [B, 64]

    # Combine left eye + right eye + face. Keeping the eyes as separate
    # inputs is what lets the model learn eye asymmetry (lazy eye).
    combined = layers.Concatenate()([left_features, right_features, face_features])

    # Fusion head
    x = layers.Dense(128, activation='relu')(combined)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(x)

    model = Model(
        inputs=[left_eye_input, right_eye_input, face_input],
        outputs=outputs,
        name='GazeInceptionLite_DualEye'
    )
    return model


if __name__ == '__main__':
    # Test single eye model
    model_single = build_gaze_inception_lite()
    model_single.summary()
    print(f"\nSingle eye model params: {model_single.count_params():,}")

    # Test with random input
    test_input = np.random.rand(2, 64, 64, 3).astype(np.float32)
    output = model_single(test_input)
    print(f"Output shape: {output.shape}")
    print(f"Output values: {output.numpy()}")
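
    # Not part of the original tests: a minimal sketch inspecting the learned
    # per-branch gates of a standalone GatedInceptionBlock. After one forward
    # pass the sub-layers are built, so the gate path can be evaluated directly.
    block = GatedInceptionBlock(
        filters_1x1=16, filters_3x3_reduce=16, filters_3x3=24,
        filters_5x5_reduce=8, filters_5x5=12, filters_pool=12,
    )
    feat = np.random.rand(2, 32, 32, 32).astype(np.float32)
    block_out = block(feat)
    print(f"Gated block output shape: {block_out.shape}")  # (2, 32, 32, 64): 16+24+12+12
    gates = block.gate_dense2(block.gate_dense1(block.gate_pool(feat)))
    print(f"Per-branch gate values (untrained): {gates.numpy()}")  # [batch, 4] sigmoids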

    print("\n" + "=" * 60)

    # Test dual eye model
    model_dual = build_dual_eye_model()
    model_dual.summary()
    print(f"\nDual eye model params: {model_dual.count_params():,}")
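
    # A minimal training-setup sketch (an assumption; the actual training
    # configuration is not part of this file): MSE on normalized [0, 1]
    # coordinates, with MAE reported in screen-fraction units.
    model_dual.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss='mse',
        metrics=['mae'],
    )
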
    test_left = np.random.rand(2, 64, 64, 3).astype(np.float32)
    test_right = np.random.rand(2, 64, 64, 3).astype(np.float32)
    test_face = np.random.rand(2, 64, 64, 3).astype(np.float32)
    output = model_dual([test_left, test_right, test_face])
    print(f"Output shape: {output.shape}")
    print(f"Output values: {output.numpy()}")
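
    # A minimal TFLite export sketch for the stated design goal (exact
    # converter settings are an assumption and may need tuning for the
    # custom layers, e.g. the dynamic slicing in CoordinateAttention).
    converter = tf.lite.TFLiteConverter.from_keras_model(model_single)
    tflite_model = converter.convert()
    with open('gaze_inception_lite.tflite', 'wb') as fh:
        fh.write(tflite_model)
    print(f"TFLite model size: {len(tflite_model) / 1024:.1f} KB")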