from torch import nn
from transformers import PretrainedConfig, PreTrainedModel, BertModel

from .configuration_bert import SimBertConfig


class SimBertModel(PreTrainedModel):
    """SimBert model: a BERT cross-encoder with a binary similarity head."""

    config_class = SimBertConfig

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        self.bert = BertModel(config=config, add_pooling_layer=True)
        self.fc = nn.Linear(config.hidden_size, 2)
        # MSE on the positive-class probability (instead of cross-entropy on
        # the raw logits) lets the model fit soft, non-binary similarity labels.
        # self.loss_fct = nn.CrossEntropyLoss()
        self.loss_fct = nn.MSELoss()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = outputs.pooler_output
        logits = self.fc(pooled_output)
        # Keep only the probability of the "similar" class as the score.
        logits = self.softmax(logits)[:, 1]
        if labels is not None:
            loss = self.loss_fct(logits.view(-1), labels.view(-1))
            return loss, logits
        return None, logits


class CosSimBertModel(PreTrainedModel):
    """CosSimBert model: a BERT bi-encoder scored by cosine similarity.

    Each input tensor carries a concatenated sentence pair along the sequence
    dimension: the first half is sentence A, the second half is sentence B.
    """

    config_class = SimBertConfig

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        self.bert = BertModel(config=config, add_pooling_layer=True)
        self.loss_fct = nn.MSELoss()

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        # Split every tensor in half along the sequence dimension to recover
        # the two sentences, then encode each half with the shared encoder.
        seq_length = input_ids.size(-1)
        a = {
            "input_ids": input_ids[:, : seq_length // 2],
            "token_type_ids": token_type_ids[:, : seq_length // 2],
            "attention_mask": attention_mask[:, : seq_length // 2],
        }
        b = {
            "input_ids": input_ids[:, seq_length // 2 :],
            "token_type_ids": token_type_ids[:, seq_length // 2 :],
            "attention_mask": attention_mask[:, seq_length // 2 :],
        }
        outputs_a = self.bert(**a)
        outputs_b = self.bert(**b)
        pooled_a_output = outputs_a.pooler_output
        pooled_b_output = outputs_b.pooler_output
        # Cosine similarity of the pooled [CLS] representations, in [-1, 1].
        logits = nn.functional.cosine_similarity(pooled_a_output, pooled_b_output)
        if labels is not None:
            loss = self.loss_fct(logits.view(-1), labels.view(-1))
            return loss, logits
        return None, logits

    def encode(self, input_ids, token_type_ids, attention_mask):
        """Embed a single (unpaired) input and return its pooled representation."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return outputs.pooler_output
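

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It shows the
# concatenated-pair input convention that CosSimBertModel.forward() expects:
# sentence A and sentence B are each padded to the same fixed length and
# concatenated along the sequence dimension, so forward() can split them at
# seq_length // 2. Assumptions: the checkpoint name "bert-base-uncased" is a
# stand-in, and SimBertConfig is assumed to accept standard BertConfig fields.
# Weights here are randomly initialised, so the printed scores are
# meaningless; this only demonstrates shapes and the call convention. Run via
# `python -m <package>.<module>` so the relative import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    config = SimBertConfig.from_pretrained("bert-base-uncased")
    model = CosSimBertModel(config)
    model.eval()

    # Pad both sentences to the same length, then concatenate A and B along
    # the sequence dimension (dim=-1) to form the paired input.
    a = tokenizer("how are you", padding="max_length", max_length=32,
                  truncation=True, return_tensors="pt")
    b = tokenizer("how do you do", padding="max_length", max_length=32,
                  truncation=True, return_tensors="pt")
    inputs = {
        key: torch.cat([a[key], b[key]], dim=-1)
        for key in ("input_ids", "token_type_ids", "attention_mask")
    }

    with torch.no_grad():
        _, similarity = model(**inputs)
    print(similarity.shape)  # (batch_size,) cosine similarities in [-1, 1]

    # encode() embeds a single, unpaired input, e.g. for offline indexing.
    with torch.no_grad():
        embedding = model.encode(**a)
    print(embedding.shape)  # (batch_size, hidden_size)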