from transformers import PreTrainedModel, BertModel
import torch
from .configuration_siamese import SiameseConfig

# Pretrained Russian BERT encoder used as the shared siamese backbone
# ('cointegrated/rubert-tiny', hidden size 312)
checkpoint = 'cointegrated/rubert-tiny'

class Lambda(torch.nn.Module):
    """Wraps an arbitrary callable as an nn.Module."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class SiameseNN(torch.nn.Module):
    """Siamese text classifier: one shared BERT encoder, an L1-style merge of
    the two sentence embeddings, and a 2-class head."""

    def __init__(self):
        super().__init__()
        # Element-wise similarity merge: 1 - |a - b| over the two encodings
        l1_norm = lambda x: 1 - torch.abs(x[0] - x[1])
        self.encoder = BertModel.from_pretrained(checkpoint)
        self.merged = Lambda(l1_norm)
        self.fc1 = torch.nn.Linear(312, 2)  # 312 = rubert-tiny hidden size
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        # x is a pair of tokenized inputs; both pass through the same encoder
        first_encoded = self.encoder(**x[0]).pooler_output
        second_encoded = self.encoder(**x[1]).pooler_output
        l1_distance = self.merged([first_encoded, second_encoded])
        fc1 = self.fc1(l1_distance)
        return self.softmax(fc1)

# Instantiate the siamese network and load the fine-tuned weights
second_model = SiameseNN()
second_model.load_state_dict(torch.load('siamese_state'))

class SiamseNNModel(PreTrainedModel):
    """transformers-compatible wrapper around SiameseNN so the model can be
    saved and loaded through the standard PreTrainedModel API."""

    config_class = SiameseConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = second_model

    def forward(self, tensor, labels=None):
        # `tensor` is a pair of tokenized inputs; the wrapped model returns
        # class probabilities of shape (batch_size, 2)
        logits = self.model(tensor)
        if labels is not None:
            loss_fn = torch.nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return {'loss': loss, 'logits': logits}
        return {'logits': logits}
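

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of running the wrapped model on a sentence pair. It assumes
# the tokenizer matching `checkpoint` and that SiameseConfig() can be built with
# its defaults; the sentence strings below are placeholders.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    first = tokenizer('first sentence', return_tensors='pt')
    second = tokenizer('second sentence', return_tensors='pt')

    model = SiamseNNModel(SiameseConfig())
    model.eval()
    with torch.no_grad():
        outputs = model([first, second])
    print(outputs['logits'])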