import torch import torch.nn as nn import torch.nn.functional as F from models import register @register('mlp') class MLP(nn.Module): def __init__(self, in_dim, out_dim, hidden_list, residual=False): super().__init__() self.in_dim = in_dim self.out_dim = out_dim self.hidden_list = hidden_list self.residual = residual if residual: self.convert = nn.Linear(in_dim, out_dim) layers = [] lastv = in_dim for hidden in hidden_list: layers.append(nn.Linear(lastv, hidden)) layers.append(nn.ReLU()) lastv = hidden layers.append(nn.Linear(lastv, out_dim)) self.layers = nn.Sequential(*layers) def forward(self, x): y = self.layers(x) if self.residual: y = y + self.convert(x) return y