import torch
from torch import nn
import torch.nn.functional as F


class MLPProberBase(nn.Module):
    """Simple MLP probe: a stack of ReLU-activated hidden layers followed by a linear output head."""

    def __init__(self, d=768, num_outputs=87):
        super().__init__()
        self.hidden_layer_sizes = [512, ]  # eval(self.cfg.hidden_layer_sizes)
        self.num_layers = len(self.hidden_layer_sizes)
        # Build the hidden layers dynamically, feeding each layer's output size
        # into the next layer's input size.
        for i, ld in enumerate(self.hidden_layer_sizes):
            setattr(self, f"hidden_{i}", nn.Linear(d, ld))
            d = ld
        self.output = nn.Linear(d, num_outputs)

    def forward(self, x):
        for i in range(self.num_layers):
            x = getattr(self, f"hidden_{i}")(x)
            # x = self.dropout(x)
            x = F.relu(x)
        output = self.output(x)
        return output
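

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of driving the prober, assuming 768-dim
# input features and 87 output classes (the constructor defaults above).
if __name__ == "__main__":
    model = MLPProberBase(d=768, num_outputs=87)
    features = torch.randn(4, 768)  # batch of 4 hypothetical embedding vectors
    logits = model(features)
    print(logits.shape)             # expected: torch.Size([4, 87])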