File size: 384 Bytes
cd5974b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
import torch.nn as nn
class GatingNetwork(nn.Module):
    """Two-layer MLP that maps each input vector to a probability distribution over experts.

    The network is ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim,
    num_experts) -> softmax``. The softmax is taken over the last axis, so for
    every input vector the ``num_experts`` output weights are non-negative and
    sum to 1.
    """

    def __init__(self, input_dim: int, num_experts: int, *, hidden_dim: int = 128) -> None:
        """Build the gating MLP.

        Args:
            input_dim: Size of the last dimension of the input tensor.
            num_experts: Number of gating weights produced per input vector.
            hidden_dim: Width of the hidden layer (defaults to 128, the
                previously hard-coded value).
        """
        super().__init__()
        # Attribute names fc1/fc2 are kept unchanged so existing state_dict
        # checkpoints remain loadable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute gating weights.

        Args:
            x: Tensor of shape ``(..., input_dim)``; any number of leading
                dimensions (including none) is accepted, as with ``nn.Linear``.

        Returns:
            Tensor of shape ``(..., num_experts)`` whose last axis sums to 1.
        """
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        # nn.Linear puts the expert scores on the last axis, so normalize over
        # dim=-1. The previous dim=1 was only correct for 2-D input: it mixed
        # up axes for 3-D input and failed for unbatched 1-D input.
        return torch.softmax(logits, dim=-1)
|