Li
add transformers
455a40f
raw
history blame
651 Bytes
from torch import nn
class ClassificationHead(nn.Module):
"""Classification Head for transformer encoders"""
def __init__(self, class_size, embed_size):
super().__init__()
self.class_size = class_size
self.embed_size = embed_size
# self.mlp1 = nn.Linear(embed_size, embed_size)
# self.mlp2 = (nn.Linear(embed_size, class_size))
self.mlp = nn.Linear(embed_size, class_size)
def forward(self, hidden_state):
# hidden_state = nn.functional.relu(self.mlp1(hidden_state))
# hidden_state = self.mlp2(hidden_state)
logits = self.mlp(hidden_state)
return logits