Adapters
English
code
medical
dnnsdunca commited on
Commit
cd5974b
1 Parent(s): fc01c7b

Create gating_network.py

Browse files
Files changed (1) hide show
  1. gating_network.py +13 -0
gating_network.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class GatingNetwork(nn.Module):
5
+ def __init__(self, input_dim, num_experts):
6
+ super(GatingNetwork, self).__init__()
7
+ self.fc1 = nn.Linear(input_dim, 128)
8
+ self.fc2 = nn.Linear(128, num_experts)
9
+
10
+ def forward(self, x):
11
+ x = torch.relu(self.fc1(x))
12
+ x = torch.softmax(self.fc2(x), dim=1)
13
+ return x