from transformers import PreTrainedModel, BertModel
import torch

from .configuration_siamese import SiameseConfig

checkpoint = 'cointegrated/rubert-tiny'


class Lambda(torch.nn.Module):
    """Wraps an arbitrary function as an nn.Module layer."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class SiameseNN(torch.nn.Module):
    """Siamese network: a shared BERT encoder, an L1-based merge layer,
    and a linear classification head."""

    def __init__(self):
        super().__init__()
        # Element-wise similarity of the two pooled embeddings: 1 - |a - b|
        l1_norm = lambda x: 1 - torch.abs(x[0] - x[1])
        self.encoder = BertModel.from_pretrained(checkpoint)
        self.merged = Lambda(l1_norm)
        self.fc1 = torch.nn.Linear(312, 2)  # rubert-tiny hidden size is 312
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        # x is a pair of tokenizer outputs; both go through the same encoder
        first_encoded = self.encoder(**x[0]).pooler_output
        second_encoded = self.encoder(**x[1]).pooler_output
        l1_distance = self.merged([first_encoded, second_encoded])
        fc1 = self.fc1(l1_distance)
        return self.softmax(fc1)


# Restore the weights trained earlier
second_model = SiameseNN()
second_model.load_state_dict(torch.load('siamese_state'))


class SiamseNNModel(PreTrainedModel):
    """transformers-compatible wrapper around the trained SiameseNN."""

    config_class = SiameseConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = second_model

    def forward(self, tensor, labels=None):
        logits = self.model(tensor)
        if labels is not None:
            loss_fn = torch.nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return {'loss': loss, 'logits': logits}
        return {'logits': logits}
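

# Usage sketch (illustrative only): it assumes the tokenizer matching
# `checkpoint` is the right one for both sentences and that SiameseConfig
# can be built with its defaults; adjust to your own config and data.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = SiamseNNModel(SiameseConfig())

    # The model expects a pair of tokenized sentences
    first = tokenizer('How do I reset my password?', return_tensors='pt')
    second = tokenizer('I forgot my password, what should I do?', return_tensors='pt')

    with torch.no_grad():
        out = model((first, second))
    print(out['logits'])  # class probabilities for the sentence pair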