from transformers import AutoModel
from torch import nn


class BERTClassifier(nn.Module):
    def __init__(self, bert_path="cointegrated/rubert-tiny2"):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_path)
        # Freeze the BERT encoder so only the classification head below is trained.
        for param in self.bert.parameters():
            param.requires_grad = False
        # Classification head: rubert-tiny2 has a hidden size of 312;
        # the final sigmoid yields a single probability for binary classification.
        self.linear = nn.Sequential(
            nn.Linear(312, 150),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(150, 1),
            nn.Sigmoid()
        )

    def forward(self, x, masks):
        # Use the [CLS] token embedding of the last hidden state as the sentence representation.
        bert_out = self.bert(x, attention_mask=masks)[0][:, 0, :]
        out = self.linear(bert_out)
        return out
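

# A minimal usage sketch (assumptions: the matching tokenizer is loaded from the same
# checkpoint, and the example sentences are placeholders, not from the original file).
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
    batch = tokenizer(["example text one", "example text two"],
                      padding=True, truncation=True, return_tensors="pt")

    model = BERTClassifier()
    model.eval()
    with torch.no_grad():
        probs = model(batch["input_ids"], batch["attention_mask"])
    print(probs.shape)  # expected: torch.Size([2, 1]), one probability per sentence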