# pcsr_carn/models/mlp.py
# Uploaded by 3587jjh ("Upload 10 files"), commit 61522a1 (verified), 862 bytes.
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import register
@register('mlp')
class MLP(nn.Module):
    """Stack of fully connected layers: ``Linear -> ReLU`` per hidden width, then ``Linear``.

    Args:
        in_dim: size of the input's last dimension.
        out_dim: size of the output's last dimension.
        hidden_list: hidden-layer widths, in order. May be empty, in which case
            the stack is a single ``Linear(in_dim, out_dim)``.
        residual: when True, a separate ``Linear(in_dim, out_dim)`` projection
            of the input (``self.convert``) is added to the stack's output.
    """

    def __init__(self, in_dim, out_dim, hidden_list, residual=False):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_list = hidden_list
        self.residual = residual

        # Registered before ``layers``: this fixes the order of parameters()
        # (which optimizer state relies on) and the RNG draws made during init.
        if residual:
            self.convert = nn.Linear(in_dim, out_dim)

        # Consecutive pairs of widths give each hidden layer's (fan_in, fan_out).
        widths = [in_dim, *hidden_list]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stack.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run the stack on ``x`` (linear layers act on its last dimension).

        When ``residual`` is set, the ``convert`` projection of ``x`` is added
        to the stack's output.
        """
        out = self.layers(x)
        return out + self.convert(x) if self.residual else out