Adapters
English
code
medical
dnnsdunca commited on
Commit
8be16c0
·
verified ·
1 Parent(s): bf26a42

Create Models/MoE_model.py

Browse files
Files changed (1) hide show
  1. Models/MoE_model.py +26 -0
Models/MoE_model.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from models.gating_network import GatingNetwork
4
+ from models.vision_expert import VisionExpert
5
+ from models.audio_expert import AudioExpert
6
+ from models.sensor_expert import SensorExpert
7
+
8
+ class MoEModel(nn.Module):
9
+ def __init__(self, input_dim, num_experts):
10
+ super(MoEModel, self).__init__()
11
+ self.gating_network = GatingNetwork(input_dim=input_dim, num_experts=num_experts)
12
+ self.experts = nn.ModuleList([VisionExpert(), AudioExpert(), SensorExpert()])
13
+ self.fc_final = nn.Linear(128, 10) # Assuming 10 possible actions
14
+
15
+ def forward(self, vision_input, audio_input, sensor_input):
16
+ vision_features = self.experts[0](vision_input)
17
+ audio_features = self.experts[1](audio_input)
18
+ sensor_features = self.experts[2](sensor_input)
19
+
20
+ combined_features = torch.cat((vision_features, audio_features, sensor_features), dim=1)
21
+ gating_weights = self.gating_network(combined_features)
22
+
23
+ expert_outputs = torch.stack([expert(combined_features) for expert in self.experts], dim=1)
24
+ final_output = torch.einsum('ij,ijk->ik', gating_weights, expert_outputs)
25
+
26
+ return self.fc_final(final_output)