|
from transformers import AutoModel |
|
from torch import nn |
|
|
|
class BERTClassifier(nn.Module):
    """Binary text classifier: a frozen BERT encoder with a small trainable head.

    The pretrained encoder is loaded from ``bert_path`` and fully frozen;
    only the feed-forward head on top of the [CLS] embedding is trained.
    The head ends in a Sigmoid, so ``forward`` returns probabilities in
    (0, 1) — pair it with ``nn.BCELoss`` (not ``BCEWithLogitsLoss``).

    Args:
        bert_path: HuggingFace model id or local path of the encoder.
        hidden_dim: width of the hidden layer in the classification head.
        dropout: dropout probability applied inside the head.
    """

    def __init__(self, bert_path="cointegrated/rubert-tiny2",
                 hidden_dim=150, dropout=0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_path)
        # Freeze the encoder so only the classification head receives
        # gradient updates.
        for param in self.bert.parameters():
            param.requires_grad = False
        # Read the embedding width from the model config instead of
        # hard-coding 312 (rubert-tiny2's hidden size), so any BERT-style
        # checkpoint works without editing this class.
        self.linear = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, hidden_dim),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, masks):
        """Run the classifier.

        Args:
            x: token-id tensor of shape (batch, seq_len).
            masks: attention mask of the same shape (1 = real token).

        Returns:
            Tensor of shape (batch, 1) with probabilities in (0, 1).
        """
        # [0] is last_hidden_state; [:, 0, :] takes the [CLS] token
        # embedding as a fixed-size sentence representation.
        cls_embedding = self.bert(x, attention_mask=masks)[0][:, 0, :]
        return self.linear(cls_embedding)