File size: 688 Bytes
701d7dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import AutoTokenizer, AutoModel
import torch


class CrossEncoderBert(torch.nn.Module):
    """Cross-encoder that scores a tokenized input with a single linear head.

    A pretrained Hugging Face encoder is run over the input, the hidden state
    at the first token position is taken, and a ``Linear(hidden_size, 1)``
    layer maps it to one scalar per sequence.

    The tokenizer is extended with an extra special token ``[Cont_token]``,
    and the encoder's input embeddings are resized to match, so that ids
    produced for the new token are valid.
    """

    def __init__(self, model_name: str = 'distilbert-base-uncased'):
        """Load the encoder and tokenizer and build the scoring head.

        Args:
            model_name: Hugging Face checkpoint name or path used for both
                the model and the tokenizer. Defaults to
                ``'distilbert-base-uncased'``.
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.add_tokens(["[Cont_token]"], special_tokens=True)
        # Adding a token grows the tokenizer's vocabulary, but not the model's
        # embedding table. Without this resize, the new token's id indexes
        # past the end of the embedding matrix and forward() fails.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.linear = torch.nn.Linear(self.model.config.hidden_size, 1)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return one score per sequence in the batch.

        Args:
            input_ids: Token ids from ``self.tokenizer``; presumably shaped
                (batch, seq_len) — TODO confirm against callers.
            attention_mask: Mask passed straight through to the encoder.

        Returns:
            Output of the linear head on the first-token hidden state.
            Assuming the usual (batch, seq_len, hidden) encoder output, this
            is shaped (batch, 1).
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at position 0 (the [CLS] slot for BERT-style
        # tokenizers) as the sequence representation.
        pooled_output = outputs.last_hidden_state[:, 0]
        return self.linear(pooled_output)