ccolas's picture
Upload 174 files
93c029f
raw
history blame
1.83 kB
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
# Default compute device name: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_activation(activation):
    """Return the activation callable registered under ``activation``.

    Supported names are ``'tanh'``, ``'relu'``, ``'mish'``, ``'sigmoid'``,
    ``'leakyrelu'`` and ``'exp'``.

    Args:
        activation: Name of the activation function.

    Returns:
        A callable taking a tensor and returning a tensor.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    # Each entry is (namespace, attribute) and is resolved only on request, so
    # an attribute missing from the installed torch build (e.g. ``F.mish``)
    # cannot break lookups of the other activations.
    registry = {
        # torch.tanh rather than F.tanh: same result, and consistent with the
        # torch.sigmoid entry below.
        'tanh': (torch, 'tanh'),
        'relu': (F, 'relu'),
        'mish': (F, 'mish'),
        'sigmoid': (torch, 'sigmoid'),
        'leakyrelu': (F, 'leaky_relu'),
        'exp': (torch, 'exp'),
    }
    try:
        namespace, attr_name = registry[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which must still be reported as
        # an unsupported activation rather than leak a lookup error.
        raise ValueError(
            f"Unsupported activation {activation!r}; "
            f"expected one of: {', '.join(registry)}"
        ) from None
    return getattr(namespace, attr_name)
class SimpleNet(nn.Module):
    """Fully connected network built from a stack of ``nn.Linear`` layers.

    Every layer except the last is followed by dropout and then the hidden
    activation. The last layer's output is returned as-is unless a final
    activation was requested.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, activation, dropout, final_activ=None):
        """Build the layer stack.

        Args:
            input_dim: Size of the input features.
            hidden_dims: Sequence (list, tuple, ...) of hidden layer sizes;
                may be empty, giving a single linear layer.
            output_dim: Size of the output features.
            activation: Name of the hidden activation, as accepted by
                ``get_activation``.
            dropout: Dropout probability passed to ``nn.Dropout``.
            final_activ: Optional name of an activation applied to the output
                of the last layer; ``None`` disables it.

        Raises:
            ValueError: If ``activation`` or ``final_activ`` is rejected by
                ``get_activation``.
        """
        super().__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.output_dim = output_dim

        # list(...) lets callers pass a tuple or other sequence of sizes.
        dims = [input_dim] + list(hidden_dims) + [output_dim]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            # One dropout module per linear layer keeps the two lists aligned;
            # the one paired with the last layer is never applied in forward().
            self.dropouts.append(nn.Dropout(dropout))

        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

        # Always define final_activ so the attribute exists in both cases.
        self.use_final_activ = final_activ is not None
        self.final_activ = get_activation(final_activ) if self.use_final_activ else None

    def forward(self, x):
        """Run ``x`` through the network.

        Args:
            x: Input tensor accepted by the first ``nn.Linear`` layer.

        Returns:
            The output of the last linear layer, passed through the final
            activation when one was configured.
        """
        last_index = self.n_layers - 1
        for i_layer, (layer, dropout) in enumerate(zip(self.linears, self.dropouts)):
            x = layer(x)
            # Hidden layers only: dropout is applied first, then the activation.
            if i_layer != last_index:
                x = self.activation(dropout(x))
        if self.use_final_activ:
            x = self.final_activ(x)
        return x