# Spaces: Sleeping  (stray page-header text, kept here only as a comment)
import torch.nn as nn
class MLP(nn.Module):
    """Multi-layer perceptron built from ``nn.Linear`` layers and ReLU.

    The network is ``num_hidden_layers`` hidden blocks followed by a final
    linear projection.  Each hidden block is ``Linear -> ReLU`` and, when
    ``drop_module`` is given, ``-> drop_module``.  No activation is applied
    after the final projection.

    Args:
        input_size: Number of features of the input to the first layer.
        hidden_size: Number of output features of every hidden linear layer.
        output_size: Number of output features of the final linear layer.
        num_hidden_layers: How many hidden blocks to stack.  With ``0`` the
            network is a single ``Linear(input_size, output_size)``.
        bias: Passed as ``bias=`` to every ``nn.Linear`` that is created.
        drop_module: Optional module inserted after the activation of every
            hidden block.  The *same instance* is reused in each block.

    Attributes:
        layer_list: Plain Python list of the layers, in application order.
        activation: The single ``nn.ReLU`` instance shared by all blocks.
        fc_layers: ``nn.Sequential`` wrapping ``layer_list``; this is what
            ``forward`` runs.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        num_hidden_layers=1,
        bias=False,
        drop_module=None,
    ):
        super().__init__()
        self.layer_list = []
        self.activation = nn.ReLU()
        self.drop_module = drop_module
        self.num_hidden_layers = num_hidden_layers

        # Width of the tensor entering the next linear layer.
        cur_output_size = input_size
        for _ in range(num_hidden_layers):
            self.layer_list.append(
                nn.Linear(cur_output_size, hidden_size, bias=bias)
            )
            # The activation (and drop module) instances are shared on
            # purpose: the same object is appended once per hidden block.
            self.layer_list.append(self.activation)
            if self.drop_module is not None:
                self.layer_list.append(self.drop_module)
            cur_output_size = hidden_size

        # Output projection; fed by input_size directly when there are no
        # hidden blocks.
        self.layer_list.append(nn.Linear(cur_output_size, output_size, bias=bias))
        self.fc_layers = nn.Sequential(*self.layer_list)

    def forward(self, mlp_input):
        """Run ``mlp_input`` through ``fc_layers`` and return the result."""
        return self.fc_layers(mlp_input)