import torch
import torch.nn as nn
import torch.nn.functional as F

from models import register


@register('mlp')
class MLP(nn.Module):
    """Multi-layer perceptron with ReLU activations and an optional
    linear residual connection from the input to the output."""

    def __init__(self, in_dim, out_dim, hidden_list, residual=False):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_list = hidden_list
        self.residual = residual
        if residual:
            # Project the input to out_dim so it can be added to the
            # output even when in_dim != out_dim.
            self.convert = nn.Linear(in_dim, out_dim)

        # Stack of Linear + ReLU blocks, followed by a final Linear to out_dim.
        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        y = self.layers(x)
        if self.residual:
            y = y + self.convert(x)
        return y
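

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file): builds an MLP
    # with two hidden layers and checks the output shape. Assumes the
    # `models.register` decorator only records the class in a registry
    # and does not change how it is constructed or called.
    mlp = MLP(in_dim=2, out_dim=3, hidden_list=[256, 256], residual=True)
    x = torch.randn(16, 2)   # batch of 16 two-dimensional inputs
    y = mlp(x)
    print(y.shape)           # expected: torch.Size([16, 3])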