import torch
import torch.nn as nn


class VanillaMLP(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        out_activation,
        n_hidden_layers=4,
        n_neurons=64,
        activation="ReLU",
    ):
        super().__init__()
        self.n_neurons = n_neurons
        self.n_hidden_layers = n_hidden_layers
        self.activation = activation
        self.out_activation = out_activation

        # Input layer plus (n_hidden_layers - 1) hidden layers, each followed by an activation.
        layers = [
            self.make_linear(input_dim, self.n_neurons, is_first=True, is_last=False),
            self.make_activation(),
        ]
        for i in range(self.n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    self.n_neurons, self.n_neurons, is_first=False, is_last=False
                ),
                self.make_activation(),
            ]
        layers += [
            self.make_linear(self.n_neurons, output_dim, is_first=False, is_last=True)
        ]

        # Output activation applied to the final linear layer.
        if self.out_activation == "sigmoid":
            layers += [nn.Sigmoid()]
        elif self.out_activation == "tanh":
            layers += [nn.Tanh()]
        elif self.out_activation == "hardtanh":
            layers += [nn.Hardtanh()]
        elif self.out_activation == "GELU":
            layers += [nn.GELU()]
        elif self.out_activation == "RELU":
            layers += [nn.ReLU()]
        else:
            raise NotImplementedError

        self.layers = nn.Sequential(*layers)

    def forward(self, x, split_size=100000):
        # Run the MLP in full precision even under autocast.
        # Note: `split_size` is accepted but unused in this implementation.
        with torch.cuda.amp.autocast(enabled=False):
            out = self.layers(x)
        return out

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=False)
        return layer

    def make_activation(self):
        if self.activation == "ReLU":
            return nn.ReLU(inplace=True)
        elif self.activation == "GELU":
            return nn.GELU()
        else:
            raise NotImplementedError