Epsilon617
add genre prediction head
92cd759
raw history blame
No virus
722 Bytes
import torch
from torch import nn
import torch.nn.functional as F
class MLPProberBase(nn.Module):
    """MLP probe head: a stack of Linear + ReLU hidden layers followed by a final Linear.

    Every layer is an ``nn.Linear``, so it acts on the last dimension of the
    input. Input ``(..., d)`` maps to output ``(..., num_outputs)``. The output
    is the raw result of the final ``Linear``; no activation is applied to it.

    Args:
        d: Size of the last dimension of the input features.
        num_outputs: Number of output units produced by the final layer.
        hidden_layer_sizes: Widths of the hidden layers, applied in order
            (keyword-only). Defaults to a single hidden layer of 512 units,
            which is the architecture this class always used before the
            parameter existed. An empty sequence gives a purely linear probe
            (input fed straight into ``output``).
    """

    def __init__(self, d: int = 768, num_outputs: int = 87, *, hidden_layer_sizes=(512,)):
        super().__init__()
        # Copy into a list so callers may pass any iterable (tuple, list, ...).
        # The default is immutable, which avoids the mutable-default pitfall.
        self.hidden_layer_sizes = list(hidden_layer_sizes)
        self.num_layers = len(self.hidden_layer_sizes)
        # Layers are registered as individual attributes (hidden_0, hidden_1, ...)
        # rather than in an nn.ModuleList. This keeps the parameter names as
        # "hidden_{i}.weight"/"hidden_{i}.bias". Switching to a ModuleList would
        # rename the state_dict keys and break loading previously saved weights.
        in_features = d
        for i, width in enumerate(self.hidden_layer_sizes):
            setattr(self, f"hidden_{i}", nn.Linear(in_features, width))
            in_features = width
        self.output = nn.Linear(in_features, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each hidden Linear followed by ReLU, then the output Linear.

        Args:
            x: Tensor whose last dimension equals ``d``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``num_outputs``.
        """
        for i in range(self.num_layers):
            x = F.relu(getattr(self, f"hidden_{i}")(x))
        return self.output(x)