Adapters
English
code
medical
UANN / gating_network.py
dnnsdunca's picture
Create gating_network.py
cd5974b verified
raw
history blame
384 Bytes
import torch
import torch.nn as nn
class GatingNetwork(nn.Module):
    """Two-layer MLP that maps input features to a distribution over experts.

    The network is ``Linear -> ReLU -> Linear -> softmax``. The softmax is
    taken over the last axis (the expert axis), so every output vector is
    non-negative and sums to 1.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        num_experts: Number of experts to produce gating weights for.
        hidden_dim: Width of the hidden layer. Defaults to 128, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim: int, num_experts: int, hidden_dim: int = 128) -> None:
        super().__init__()
        # Attribute names are kept as ``fc1``/``fc2`` so that state dicts
        # saved from earlier versions of this class still load.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return gating weights for ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``; any number
                of leading dimensions (including none) is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``num_experts``, summing to 1 along that axis.
        """
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        # Normalise over the expert axis. ``dim=-1`` is identical to the
        # former ``dim=1`` for 2-D (batch, features) input, but also works
        # for unbatched 1-D input and for inputs with extra leading axes,
        # where ``dim=1`` would fail or normalise over the wrong axis.
        return torch.softmax(logits, dim=-1)