import torch
from torch import nn
import torch.nn.functional as F


class MLPProberBase(nn.Module):
    """MLP probe: a stack of ReLU hidden layers followed by a linear output head."""

    def __init__(self, d=768, num_outputs=87):
        super().__init__()
        self.hidden_layer_sizes = [512]  # e.g. eval(self.cfg.hidden_layer_sizes) when config-driven
        self.num_layers = len(self.hidden_layer_sizes)
        # Register one linear layer per hidden size, chaining dimensions
        # so each layer's input width matches the previous layer's output.
        for i, ld in enumerate(self.hidden_layer_sizes):
            setattr(self, f"hidden_{i}", nn.Linear(d, ld))
            d = ld
        self.output = nn.Linear(d, num_outputs)
    def forward(self, x):
        for i in range(self.num_layers):
            x = getattr(self, f"hidden_{i}")(x)
            # x = self.dropout(x)  # optional regularization, disabled here
            x = F.relu(x)
        return self.output(x)
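

# A minimal usage sketch, not part of the original source: the batch size of 4
# and the random input features are illustrative assumptions; d=768 and
# num_outputs=87 match the defaults above.
if __name__ == "__main__":
    probe = MLPProberBase(d=768, num_outputs=87)
    feats = torch.randn(4, 768)  # batch of 4 feature vectors of width d
    logits = probe(feats)        # raw scores, shape (batch, num_outputs)
    print(logits.shape)          # torch.Size([4, 87])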