import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ZIAModel(nn.Module):
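    """Multimodal intent-prediction model.

    Gaze, heart-rate, EEG, and context features are each projected into a shared
    d_model-dimensional embedding space, fused by element-wise averaging, and fed
    through a Transformer encoder; the mean-pooled sequence representation is
    mapped to intent logits.
    """
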
    def __init__(self, n_intents=10, d_model=128, nhead=8, num_layers=6, dim_feedforward=512):
        super(ZIAModel, self).__init__()
        self.d_model = d_model
        
        # Modality-specific encoders
        self.gaze_encoder = nn.Linear(2, d_model)   # 2 gaze features per timestep
        self.hr_encoder = nn.Linear(1, d_model)     # 1 heart-rate feature per timestep
        self.eeg_encoder = nn.Linear(4, d_model)    # 4 EEG features per timestep
        self.context_encoder = nn.Linear(32 + 3 + 20, d_model)  # Time (32) + Location (3) + Usage (20)
        
        # Transformer
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout=0.1, batch_first=True)
        self.transformer = TransformerEncoder(encoder_layer, num_layers)
        
        # Output layer
        self.fc = nn.Linear(d_model, n_intents)
        
    def forward(self, gaze, hr, eeg, context):
        # Encode modalities
        gaze_emb = self.gaze_encoder(gaze)  # [batch, seq, d_model]
        hr_emb = self.hr_encoder(hr.unsqueeze(-1))
        eeg_emb = self.eeg_encoder(eeg)
        context_emb = self.context_encoder(context)
        
        # Fuse modalities
        fused = (gaze_emb + hr_emb + eeg_emb + context_emb) / 4  # Simple averaging
        
        # Transformer
        output = self.transformer(fused)
        output = output.mean(dim=1)  # Pool over sequence
        
        # Predict intent
        logits = self.fc(output)
        return logits

# Example usage
if __name__ == "__main__":
    model = ZIAModel()
    print(model)
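
    # A minimal sketch of a forward pass with random inputs; the batch size of 2
    # and sequence length of 50 are illustrative assumptions, not values taken
    # from the original file. Feature sizes match the encoders defined in __init__.
    batch, seq = 2, 50
    gaze = torch.randn(batch, seq, 2)       # 2 gaze features per timestep
    hr = torch.randn(batch, seq)            # heart rate; unsqueezed to (batch, seq, 1) in forward()
    eeg = torch.randn(batch, seq, 4)        # 4 EEG features per timestep
    context = torch.randn(batch, seq, 55)   # time (32) + location (3) + usage (20)
    logits = model(gaze, hr, eeg, context)
    print(logits.shape)  # torch.Size([2, 10]) with the default n_intents=10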