import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Expected feature widths of the concatenated context vector.
# NOTE(review): the 32/3/20 split (time/location/usage) is taken from the
# original inline comment — confirm against the data pipeline.
CONTEXT_TIME_DIM = 32
CONTEXT_LOCATION_DIM = 3
CONTEXT_USAGE_DIM = 20
CONTEXT_DIM = CONTEXT_TIME_DIM + CONTEXT_LOCATION_DIM + CONTEXT_USAGE_DIM


class ZIAModel(nn.Module):
    """Multi-modal intent classifier.

    Projects four input modalities (gaze, heart rate, EEG, context) into a
    shared embedding space, fuses them by simple averaging, runs the fused
    sequence through a Transformer encoder, mean-pools over the sequence
    dimension, and emits one logit per intent class.

    Args:
        n_intents: Number of output intent classes.
        d_model: Shared embedding width for all modalities and the
            Transformer.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked Transformer encoder layers.
        dim_feedforward: Width of each encoder layer's feed-forward block.
        dropout: Dropout probability inside the encoder layers
            (default 0.1, matching the previous hard-coded value).
    """

    def __init__(
        self,
        n_intents: int = 10,
        d_model: int = 128,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.d_model = d_model

        # Modality-specific linear encoders: each maps its raw feature
        # width into the shared d_model space so the modalities can be
        # averaged together in forward().
        self.gaze_encoder = nn.Linear(2, d_model)   # gaze: 2 features (e.g. x/y)
        self.hr_encoder = nn.Linear(1, d_model)     # heart rate: scalar per step
        self.eeg_encoder = nn.Linear(4, d_model)    # EEG: 4 channels
        self.context_encoder = nn.Linear(CONTEXT_DIM, d_model)

        # batch_first=True so all tensors are [batch, seq, d_model].
        encoder_layer = TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers)

        # Final projection from the pooled representation to intent logits.
        self.fc = nn.Linear(d_model, n_intents)

    def forward(
        self,
        gaze: torch.Tensor,
        hr: torch.Tensor,
        eeg: torch.Tensor,
        context: torch.Tensor,
    ) -> torch.Tensor:
        """Compute intent logits from the four modality streams.

        Args:
            gaze: [batch, seq, 2] gaze features.
            hr: [batch, seq] heart-rate values; a trailing feature axis is
                added internally via ``unsqueeze(-1)``.
            eeg: [batch, seq, 4] EEG features.
            context: [batch, seq, 55] concatenated context features
                (time 32 + location 3 + usage 20).

        Returns:
            [batch, n_intents] unnormalized intent logits.
        """
        # Project every modality into the shared embedding space.
        gaze_emb = self.gaze_encoder(gaze)             # [batch, seq, d_model]
        hr_emb = self.hr_encoder(hr.unsqueeze(-1))     # add feature axis first
        eeg_emb = self.eeg_encoder(eeg)
        context_emb = self.context_encoder(context)

        # Fuse by simple (unweighted) averaging of the four embeddings.
        fused = (gaze_emb + hr_emb + eeg_emb + context_emb) / 4

        # Contextualize over the sequence, then mean-pool it away.
        output = self.transformer(fused)
        output = output.mean(dim=1)

        return self.fc(output)


# Example usage
if __name__ == "__main__":
    model = ZIAModel()
    print(model)